vllm.utils.deep_gemm

Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should import DeepGEMM functionality only through these wrappers, never from the deep_gemm package directly.
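
For example, call sites import the exported wrapper names rather than deep_gemm itself:

from vllm.utils.deep_gemm import (
    fp8_gemm_nt,
    is_blackwell_deep_gemm_used,
    per_block_cast_to_fp8,
)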

__all__ module-attribute

__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]

_dg module-attribute

_dg = import_module('deep_gemm')

_fp8_gemm_nt_impl module-attribute

_fp8_gemm_nt_impl: Callable[..., Any] | None = None

_grouped_impl module-attribute

_grouped_impl: Callable[..., Any] | None = None

_grouped_masked_impl module-attribute

_grouped_masked_impl: Callable[..., Any] | None = None

_math_mod module-attribute

_math_mod = import_module('deep_gemm.utils.math')

_per_block_cast_impl module-attribute

_per_block_cast_impl: Callable[..., Any] | None = getattr(
    _math_mod, "per_block_cast_to_fp8", None
)

_missing

_missing(*_: Any, **__: Any) -> NoReturn

Placeholder for unavailable DeepGEMM backend.

Source code in vllm/utils/deep_gemm.py
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")

_resolve_symbol

_resolve_symbol(
    module, new: str, old: str
) -> Callable[..., Any] | None

Return the new symbol if it exists, otherwise the old one.

Source code in vllm/utils/deep_gemm.py
def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Return the *new* symbol if it exists, otherwise the *old* one."""
    if hasattr(module, new):
        return getattr(module, new)
    if hasattr(module, old):
        return getattr(module, old)
    return None
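
A minimal illustration with a stand-in module; the symbol names below are hypothetical and not real DeepGEMM APIs:

from types import SimpleNamespace

# A module that only exposes the legacy name.
legacy_mod = SimpleNamespace(old_gemm=lambda: "legacy")

# The new name is absent, so the old symbol is returned.
fn = _resolve_symbol(legacy_mod, new="new_gemm", old="old_gemm")
assert fn is legacy_mod.old_gemm

# Neither name exists: None is returned and callers fall back to _missing.
assert _resolve_symbol(SimpleNamespace(), "new_gemm", "old_gemm") is None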

calc_diff

calc_diff(x: Tensor, y: Tensor)

Return a global difference metric for unit tests.

DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.

Source code in vllm/utils/deep_gemm.py
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``.  Once kernel accuracy improves this helper can be
    removed.
    """

    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
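
In a test this metric is compared against a small global threshold instead of asserting element-wise closeness; the tolerance below is illustrative, not the value used by vLLM's tests:

import torch

out = torch.randn(128, 128)
ref = out + 1e-4 * torch.randn_like(out)  # small perturbation stands in for kernel error

# 0.0 means identical tensors; a loose global bound replaces assert_close.
assert calc_diff(out, ref) < 1e-3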

fp8_gemm_nt

fp8_gemm_nt(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_gemm_nt(*args, **kwargs):
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)
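
All three GEMM wrappers follow the same dispatch pattern: if the resolved implementation is None, the call lands in _missing and raises RuntimeError. A quick sketch of that failure mode when deep_gemm is not installed:

try:
    fp8_gemm_nt()  # arguments omitted; dispatch fails before they are inspected
except RuntimeError as err:
    print(err)  # "DeepGEMM backend is not available. ..."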

fp8_m_grouped_gemm_nt_masked

fp8_m_grouped_gemm_nt_masked(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)

is_blackwell_deep_gemm_used cached

is_blackwell_deep_gemm_used() -> bool

Return True if vLLM is configured to use DeepGEMM on a Blackwell-class GPU.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.
    """

    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
            and _per_block_cast_impl is not None):
        return False

    return cuda_get_device_properties(0, ("major", ))[0] == 10
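
Compute capability major 10 corresponds to Blackwell-class GPUs (e.g. B200). A typical gating sketch, with the fallback branch left hypothetical:

if is_blackwell_deep_gemm_used():
    ...  # take the DeepGEMM FP8 kernel path
else:
    ...  # fall back to another quantization/GEMM backend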

m_grouped_fp8_gemm_nt_contiguous

m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)

per_block_cast_to_fp8

per_block_cast_to_fp8(x, *args, **kwargs)
Source code in vllm/utils/deep_gemm.py
def per_block_cast_to_fp8(x, *args, **kwargs):
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        return _per_block_cast_impl(x)
    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)
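
A hedged usage sketch; the (quantized tensor, per-block scales) return convention assumed below follows the test helper this wrapper falls back to, and is not spelled out by the wrapper's own signature:

import torch

x = torch.randn(512, 7168, dtype=torch.bfloat16)

# Assumed return convention: FP8 tensor plus per-(128x128)-block scales.
x_fp8, x_scales = per_block_cast_to_fp8(x)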