Skip to content

[BUG]: CuTile doesn't support graph mode #80

@LironKesem

Description

@LironKesem

Version

cuda-tile 1.1.0

CUDA Version

13.2

Which installation method(s) does this occur on?

No response

Describe the bug.

cuda-tile                                1.1.0
nvidia-cuda-tileiras                     13.1.80

When running the reproducible example below under torch.compile(fullgraph=True), Dynamo fails with a graph break because it cannot trace `cuda.tile._cext.launch` (see the log output):

Minimum reproducible example

import torch                                                                                   
import cuda.tile as ct                                                                         
                 
# Marker type for kernel parameters that are compile-time constant integers.
ConstInt = ct.Constant[int]
# Default tile sizes (overridden to 128 by the wrapper/test below).
TILE_M=32
TILE_N=32
TILE_K=32
@ct.kernel(num_ctas=ct.ByTarget(sm_121=4))
def matmul_kernel(A, B, As, Bs, C,
                  M: ConstInt, N: ConstInt, K: ConstInt,
                  TILE_M: ConstInt, TILE_N: ConstInt, TILE_K: ConstInt):
    """Block-scaled tiled matmul: C = sum over k-tiles of (A_tile @ B_tile) * (a_scale * b_scale).

    Launched on a 1-D grid; the linear block id is decomposed into the
    (bid_m, bid_n) coordinate of the output tile this block computes.
    A/B hold the operand tiles, As/Bs hold per-tile scale factors
    (As: one scale per (row-tile row, k-tile); Bs: one scale per (k-tile, col-tile)).
    """
    # Decompose the linear block index into a 2-D output-tile coordinate.
    bid_m = ct.bid(0) // ct.cdiv(N, TILE_N)
    bid_n = ct.bid(0) % ct.cdiv(N, TILE_N)
    num_tiles_k = ct.cdiv(K, TILE_K)

    # fp32 accumulator for the output tile.
    acc = ct.zeros((TILE_M, TILE_N), dtype=ct.float32)

    for k_idx in range(num_tiles_k):
        a_tile = ct.load(A, index=(bid_m, k_idx), shape=(TILE_M, TILE_K))
        b_tile = ct.load(B, index=(k_idx, bid_n), shape=(TILE_K, TILE_N))
        a_scale = ct.load(As, index=(bid_m, k_idx), shape=(TILE_M, 1))
        b_scale = ct.load(Bs, index=(k_idx, bid_n), shape=(1, 1))

        # Per-k-tile product, then scale before accumulating.
        dot_prod = ct.mma(a_tile, b_tile, ct.zeros((TILE_M, TILE_N), dtype=ct.float32))
        acc += dot_prod * (a_scale * b_scale)

    # BUGFIX: store once after the K loop. The original stored inside the
    # loop body, issuing a redundant global-memory write per K iteration
    # (the final value was still correct only because the last store wins).
    ct.store(C, index=(bid_m, bid_n), tile=ct.astype(acc, C.dtype))


def cutile_mm_wrapper(out, A, B, As, Bs):
    """Launch matmul_kernel on the current CUDA stream, writing A@B (tile-scaled) into `out`."""
    M, K = A.shape
    K_check, N = B.shape
    assert K == K_check
    TILE_M, TILE_N, TILE_K = 128, 128, 128

    # Flatten the (m-tiles x n-tiles) output-tile space into a 1-D launch grid;
    # the kernel recovers (bid_m, bid_n) from the linear block id.
    num_blocks = ct.cdiv(M, TILE_M) * ct.cdiv(N, TILE_N)
    grid = (num_blocks, 1, 1)

    stream_ptr = torch.cuda.current_stream().cuda_stream
    kernel_args = (A, B, As, Bs, out, M, N, K, TILE_M, TILE_N, TILE_K)
    ct.launch(stream_ptr, grid, matmul_kernel, kernel_args)
    return out

def test_torch_compile():
    """Parity check: eager launch vs. torch.compile(fullgraph=True) of the cuTile matmul."""
    torch.manual_seed(0)
    device = "cuda"
    M, N, K = 256, 256, 256
    TILE_M, TILE_N, TILE_K = 128, 128, 128

    # FP8 operands; keep the randn call order so the seeded RNG sequence matches.
    A_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    B_fp8 = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous()

    # One scale per (row-tile, k-tile) of A and per (k-tile, col-tile) of B.
    tiles_m, tiles_n, tiles_k = (
        ct.cdiv(M, TILE_M),
        ct.cdiv(N, TILE_N),
        ct.cdiv(K, TILE_K),
    )
    As = torch.ones(tiles_m, tiles_k, device=device, dtype=torch.float32).t().contiguous()
    Bs = torch.ones(tiles_k, tiles_n, device=device, dtype=torch.float32).t().contiguous()

    out_dtype = torch.bfloat16
    out_no_compile = torch.empty((M, N), dtype=out_dtype, device=device)
    cutile_mm_wrapper(out_no_compile, A_fp8, B_fp8, As, Bs)

    compiled_fn = torch.compile(cutile_mm_wrapper, fullgraph=True)
    out_compile = torch.empty((M, N), dtype=out_dtype, device=device)
    compiled_fn(out_compile, A_fp8, B_fp8, As, Bs)

    assert torch.allclose(out_no_compile, out_compile, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    # Allow running the reproduction directly (outside pytest).
    test_torch_compile()




This is the minimal reproduction code:

import torch                                                                                   
import cuda.tile as ct                                                                         
                 
# Marker type for kernel parameters that are compile-time constant integers.
ConstInt = ct.Constant[int]
# Default tile sizes (overridden to 128 by the wrapper/test below).
TILE_M=32
TILE_N=32
TILE_K=32
@ct.kernel(num_ctas=ct.ByTarget(sm_121=4))
def matmul_kernel(A, B, As, Bs, C,
                  M: ConstInt, N: ConstInt, K: ConstInt,
                  TILE_M: ConstInt, TILE_N: ConstInt, TILE_K: ConstInt):
    """Block-scaled tiled matmul: C = sum over k-tiles of (A_tile @ B_tile) * (a_scale * b_scale).

    Launched on a 1-D grid; the linear block id is decomposed into the
    (bid_m, bid_n) coordinate of the output tile this block computes.
    """
    # Decompose the linear block index into a 2-D output-tile coordinate.
    bid_m = ct.bid(0) // ct.cdiv(N, TILE_N)
    bid_n = ct.bid(0) % ct.cdiv(N, TILE_N)
    num_tiles_k = ct.cdiv(K, TILE_K)

    # fp32 accumulator for the output tile.
    acc = ct.zeros((TILE_M, TILE_N), dtype=ct.float32)

    for k_idx in range(num_tiles_k):
        a_tile = ct.load(A, index=(bid_m, k_idx), shape=(TILE_M, TILE_K))
        b_tile = ct.load(B, index=(k_idx, bid_n), shape=(TILE_K, TILE_N))
        a_scale = ct.load(As, index=(bid_m, k_idx), shape=(TILE_M, 1))
        b_scale = ct.load(Bs, index=(k_idx, bid_n), shape=(1, 1))

        # Per-k-tile product, then scale before accumulating.
        dot_prod = ct.mma(a_tile, b_tile, ct.zeros((TILE_M, TILE_N), dtype=ct.float32))
        acc += dot_prod * (a_scale * b_scale)

    # BUGFIX: store once after the K loop. The original stored inside the
    # loop body, issuing a redundant global-memory write per K iteration
    # (the final value was still correct only because the last store wins).
    ct.store(C, index=(bid_m, bid_n), tile=ct.astype(acc, C.dtype))


def cutile_mm_wrapper(out, A, B, As, Bs):
    """Launch matmul_kernel on the current CUDA stream, writing A@B (tile-scaled) into `out`."""
    M, K = A.shape
    K_check, N = B.shape
    assert K == K_check
    TILE_M, TILE_N, TILE_K = 128, 128, 128

    # Flatten the (m-tiles x n-tiles) output-tile space into a 1-D launch grid;
    # the kernel recovers (bid_m, bid_n) from the linear block id.
    num_blocks = ct.cdiv(M, TILE_M) * ct.cdiv(N, TILE_N)
    grid = (num_blocks, 1, 1)

    stream_ptr = torch.cuda.current_stream().cuda_stream
    kernel_args = (A, B, As, Bs, out, M, N, K, TILE_M, TILE_N, TILE_K)
    ct.launch(stream_ptr, grid, matmul_kernel, kernel_args)
    return out

def test_torch_compile():
    """Parity check: eager launch vs. torch.compile(fullgraph=True) of the cuTile matmul."""
    torch.manual_seed(0)
    device = "cuda"
    M, N, K = 256, 256, 256
    TILE_M, TILE_N, TILE_K = 128, 128, 128

    # FP8 operands; keep the randn call order so the seeded RNG sequence matches.
    A_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    B_fp8 = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous()

    # One scale per (row-tile, k-tile) of A and per (k-tile, col-tile) of B.
    tiles_m, tiles_n, tiles_k = (
        ct.cdiv(M, TILE_M),
        ct.cdiv(N, TILE_N),
        ct.cdiv(K, TILE_K),
    )
    As = torch.ones(tiles_m, tiles_k, device=device, dtype=torch.float32).t().contiguous()
    Bs = torch.ones(tiles_k, tiles_n, device=device, dtype=torch.float32).t().contiguous()

    out_dtype = torch.bfloat16
    out_no_compile = torch.empty((M, N), dtype=out_dtype, device=device)
    cutile_mm_wrapper(out_no_compile, A_fp8, B_fp8, As, Bs)

    compiled_fn = torch.compile(cutile_mm_wrapper, fullgraph=True)
    out_compile = torch.empty((M, N), dtype=out_dtype, device=device)
    compiled_fn(out_compile, A_fp8, B_fp8, As, Bs)

    assert torch.allclose(out_no_compile, out_compile, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    # Allow running the reproduction directly (outside pytest).
    test_torch_compile()

Relevant log output

python -m pytest -s tests/kernels/quantization/test_failing_graph_mode.py 
DEBUG 04-03 00:57:26 [plugins/__init__.py:36] No plugins for group vllm.platform_plugins found.
DEBUG 04-03 00:57:26 [platforms/__init__.py:37] Checking if TPU platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:56] TPU platform is not available because: No module named 'libtpu'
DEBUG 04-03 00:57:26 [platforms/__init__.py:62] Checking if CUDA platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:85] Confirmed CUDA platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:113] Checking if ROCm platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:127] ROCm platform is not available because: No module named 'amdsmi'
DEBUG 04-03 00:57:26 [platforms/__init__.py:134] Checking if XPU platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:165] Checking if CPU platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:62] Checking if CUDA platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:85] Confirmed CUDA platform is available.
DEBUG 04-03 00:57:26 [platforms/__init__.py:247] Automatically detected platform cuda.
============================== test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/lkesem/vllm
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 1 item                                                                

tests/kernels/quantization/test_failing_graph_mode.py F

=================================== FAILURES ====================================
______________________________ test_torch_compile _______________________________

    def test_torch_compile():
        torch.manual_seed(0)
        device = "cuda"
        M, N, K = 256, 256, 256
        TILE_M, TILE_N, TILE_K = 128, 128, 128
    
        A_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
        B_fp8 = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous()
    
        m_tiles = ct.cdiv(M, TILE_M)
        n_tiles = ct.cdiv(N, TILE_N)
        k_tiles = ct.cdiv(K, TILE_K)
    
        As = torch.ones(m_tiles, k_tiles, device=device, dtype=torch.float32).t().contiguous()
        Bs = torch.ones(k_tiles, n_tiles, device=device, dtype=torch.float32).t().contiguous()
    
        out_dtype = torch.bfloat16
        out_no_compile = torch.empty((M, N), dtype=out_dtype, device=device)
        cutile_mm_wrapper(out_no_compile, A_fp8, B_fp8, As, Bs)
        compiled_fn = torch.compile(cutile_mm_wrapper, fullgraph=True)
        out_compile = torch.empty((M, N), dtype=out_dtype, device=device)
>       compiled_fn(out_compile, A_fp8, B_fp8, As, Bs)

tests/kernels/quantization/test_failing_graph_mode.py:66: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (tensor([[-2.3865e-09, -1.7043e-07,  2.7865e-06,  ...,  8.6914e-02,
         -2.8750e+00,  5.8620e-13],
        [-6.95..._e4m3fn), tensor([[1., 1.],
        [1., 1.]], device='cuda:0'), tensor([[1., 1.],
        [1., 1.]], device='cuda:0'))
kwargs = {}, prior = None
prior_eval_frame_override = <_EvalFrameOverride.NONE: 0>, tracing_context = None
cleanups = [<function nothing at 0xf4ca6233b2e0>]
prior_skip_guard_eval_unsafe = False, prior_error_on_graph_break = None
saved_dynamic_layer_stack_depth = 0
cleanup = <function nothing at 0xf4ca6233b2e0>

    @functools.wraps(fn)
    def compile_wrapper(*args: Any, **kwargs: Any) -> Any:
        # NB: function calls here could change global state (e.g. random state)
        # and that can result in different behavior between eager and compiled!
        # In particular, we don't have control over internal functions like justknobs_check
        # called in _maybe_set_eval_frame.
        # Unlike in eval_frame_cpp.cpp/convert_frame.py, we don't attempt to restore global state
        # due to additional overhead costs.
        prior = set_eval_frame(None)
        prior_eval_frame_override: _EvalFrameOverride | None = None
        if self.fullgraph:
            prior_eval_frame_override = set_eval_frame_override(
                _get_eval_frame_override()
            )
        try:
            # We shouldn't compile inside kernel invocation.
            if tracing_context := torch._guards.TracingContext.try_get():
                if (
                    tracing_context.fake_mode is not None
                    and tracing_context.fake_mode.in_kernel_invocation
                ):
                    return fn(*args, **kwargs)
            # Skip nested compile during export (but not HOP internal compile)
            # Only skip if there's an active TracingContext (nested), not for top-level export
            if (
                torch.compiler.is_exporting()
                and not config.force_compile_during_fx_trace
            ):
                from torch._higher_order_ops.utils import _in_hop_compile
    
                if not _in_hop_compile():
                    if torch._guards.TracingContext.try_get() is not None:
                        return fn(*args, **kwargs)
            # Skip nested compile - just inline the function
            if (
                is_fx_symbolic_tracing()
                and not config.force_compile_during_fx_trace
            ):
                if config.error_on_nested_fx_trace:
                    raise RuntimeError(
                        "Detected that you are using FX to symbolically trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )
                else:
                    return fn(*args, **kwargs)
    
            if is_jit_tracing():
                raise RuntimeError(
                    "Detected that you are using FX to torch.jit.trace "
                    "a dynamo-optimized function. This is not supported at the moment."
                )
    
            cleanups = [enter() for enter in self.enter_exit_hooks]
            prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
                _is_skip_guard_eval_unsafe_stance()
            )
            prior_error_on_graph_break = None
            if not self.fullgraph and self.error_on_graph_break is not None:
                prior_error_on_graph_break = _get_error_on_graph_break()
                _set_error_on_graph_break(self.error_on_graph_break)
    
            # Ensure that if an assertion occurs after graph pushes
            # something onto the DynamicLayerStack then we pop it off (the
            # constructed graph code isn't guarded with try/finally).
            #
            # This used to be a context but putting a `with` here is a noticeable
            # perf regression (#126293)
            saved_dynamic_layer_stack_depth = (
                torch._C._functorch.get_dynamic_layer_stack_depth()
            )
    
            _maybe_set_eval_frame(_callback_from_stance(callback))
    
            try:
                return fn(*args, **kwargs)
            except (Unsupported, UncapturedHigherOrderOpError, UserError) as e:
                if config.verbose:
                    raise
                # strip internal tracebacks from causes
                cur_exn: BaseException = e
                while cur_exn.__cause__ is not None:
                    cur_exn.__cause__.with_traceback(None)
                    cur_exn = cur_exn.__cause__
    
>               raise e.with_traceback(None) from e.__cause__  # User compiler error
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E               torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
E                 Explanation: Dynamo does not know how to trace the builtin `cuda.tile._cext.launch.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
E                 Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
E                 Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
E               
E                 Developer debug context: module: cuda.tile._cext, qualname: launch, skip reason: cannot determine source file for cuda.tile._cext (likely a C extension or builtin)
E               
E                For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html
E               
E               from user code:
E                  File "/home/lkesem/vllm/tests/kernels/quantization/test_failing_graph_mode.py", line 42, in cutile_mm_wrapper
E                   ct.launch(stream_ptr, grid, matmul_kernel,(A, B, As, Bs, out, M, N, K, TILE_M, TILE_N, TILE_K))
E               
E               Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

../.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1046: Unsupported
=============================== warnings summary ================================
../.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:1984
  /home/lkesem/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:1984: FutureWarning: torch._dynamo.allow_in_graph is deprecated and will be removed in a future version. Use torch._dynamo.nonstrict_trace instead.
    @torch._dynamo.allow_in_graph

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/kernels/quantization/test_failing_graph_mode.py::test_torch_compile
  /home/lkesem/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:2294: UserWarning: Dynamo does not know how to trace the builtin `cuda.tile._cext.launch.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
  If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
  If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
    torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))

tests/kernels/quantization/test_failing_graph_mode.py: 14 warnings
  /home/lkesem/.venv/lib/python3.12/site-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================ short test summary info ============================
FAILED tests/kernels/quantization/test_failing_graph_mode.py::test_torch_compile - torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
======================== 1 failed, 18 warnings in 1.42s =========================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
(lkesem) root@thinkstationpgx-1aec:~/lkesem/vllm$

Full env printout

python collect_env.py 
Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 24.04.4 LTS (aarch64)
GCC version                  : (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version                : Could not collect
CMake version                : version 4.2.3
Libc version                 : glibc-2.39

==============================
       PyTorch Info
==============================
PyTorch version              : 2.12.0.dev20260325+cu130
Is debug build               : False
CUDA used to build PyTorch   : 13.0
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-6.17.0-1008-nvidia-aarch64-with-glibc2.39

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 13.2.51
CUDA_MODULE_LOADING set to   : 
GPU models and configuration : GPU 0: NVIDIA GB10
Nvidia driver version        : 580.126.09
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            aarch64
CPU op-mode(s):                          64-bit
Byte Order:                              Little Endian
CPU(s):                                  20
On-line CPU(s) list:                     0-19
Vendor ID:                               ARM
Model name:                              Cortex-X925
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
Frequency boost:                         disabled
CPU(s) scaling MHz:                      100%
CPU max MHz:                             3900.0000
CPU min MHz:                             1378.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
Model name:                              Cortex-A725
Model:                                   1
Thread(s) per core:                      1
Core(s) per socket:                      10
Socket(s):                               1
Stepping:                                r0p1
CPU(s) scaling MHz:                      100%
CPU max MHz:                             2808.0000
CPU min MHz:                             338.0000
BogoMIPS:                                2000.00
Flags:                                   fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti ecv afp wfxt
L1d cache:                               1.3 MiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                25 MiB (20 instances)
L3 cache:                                24 MiB (2 instances)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-19
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; CSV2, BHB
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.6.7
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-crt==13.2.51
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvcc==13.2.51
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cuda-tileiras==13.1.80
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-cufile==1.15.1.6
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-cutlass-dsl==4.4.1
[pip3] nvidia-cutlass-dsl-libs-base==4.4.1
[pip3] nvidia-ml-py==13.590.48
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvshmem-cu13==3.4.5
[pip3] nvidia-nvtx==13.0.85
[pip3] nvidia-nvvm==13.2.51
[pip3] pyzmq==27.1.0
[pip3] torch==2.12.0.dev20260325+cu130
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torchaudio==2.11.0.dev20260401+cu130
[pip3] torchvision==0.27.0.dev20260401+cu130
[pip3] transformers==4.57.6
[pip3] triton==3.6.0+git9844da95
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.18.2rc1.dev28+gdcf8a2477.d20260402 (git sha: dcf8a2477, date: 20260402)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
        GPU0    NIC0    NIC1    NIC2    NIC3    CPU Affinity    NUMA AffinityGPU NUMA ID
GPU0     X      NODE    NODE    NODE    NODE    0-19    0               N/A
NIC0    NODE     X      PIX     NODE    NODE
NIC1    NODE    PIX      X      NODE    NODE
NIC2    NODE    NODE    NODE     X      PIX
NIC3    NODE    NODE    NODE    PIX      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: rocep1s0f0
  NIC1: rocep1s0f1
  NIC2: roceP2p1s0f0
  NIC3: roceP2p1s0f1

==============================
     Environment Variables
==============================
LD_LIBRARY_PATH=/usr/local/cuda-13.2/compat:/usr/local/cuda/lib64:/home/lkesem/vllm/.venv/lib/python3.12/site-packages/torch/lib:
VLLM_LOGGING_LEVEL=DEBUG
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1

Other/Misc.

No response

Contributing Guidelines

  • I agree to follow cuTile Python's contributing guidelines
  • I have searched the open bugs and have found no duplicates for this bug report

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions