Skip to content

vllm.model_executor.layers.quantization.utils.marlin_utils_fp4

_nvfp4_compute_scale_factor

_nvfp4_compute_scale_factor(marlin_scales: Tensor) -> float

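Compute the smallest power-of-2 scale_factor such that every positive value of marlin_scales * 2^7, once multiplied by scale_factor, is >= 2. Returns a Python float (a power of 2, >= 1.0); the result is 1.0 when there are no positive values or when they already satisfy the bound.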

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
    """Compute the power-of-2 scale_factor needed so that all non-zero
    values in marlin_scales * 2^7 are >= 2 after rescaling.
    Returns a Python float (power of 2, >= 1.0)."""
    ws_float = marlin_scales.float() * (2**7)
    nonzero_mask = ws_float > 0
    if nonzero_mask.any():
        min_val = ws_float[nonzero_mask].min()
        if min_val < 2:
            sf = (2 / min_val).log2().ceil().exp2()
            return sf.item()
    return 1.0

nvfp4_marlin_process_scales

nvfp4_marlin_process_scales(
    marlin_scales: Tensor, scale_factor: float | None = None
) -> tuple[Tensor, float]

Process NVFP4 weight scales into the special S0E5M3 format for Marlin.

Parameters:

Name Type Description Default
marlin_scales Tensor

Weight scales tensor in half precision, already permuted for the Marlin kernel layout.

required
scale_factor float | None

Optional power-of-2 rescaling factor. If None, the factor is computed automatically so that every non-zero scale satisfies scale * 2^7 >= 2 (i.e., the MSB of the S0E5M3 representation is always 1). When provided (e.g., for MoE layers where all experts must share the same factor), the given value is used directly. The caller is responsible for dividing global_scale by the returned scale_factor to preserve numerical correctness.

None

Returns:

Type Description
tuple[Tensor, float]

A tuple of (processed_scales, scale_factor).

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
def nvfp4_marlin_process_scales(
    marlin_scales: torch.Tensor,
    scale_factor: float | None = None,
) -> tuple[torch.Tensor, float]:
    """Process NVFP4 weight scales into the special S0E5M3 format for Marlin.

    Args:
        marlin_scales: Weight scales tensor in half precision, already
            permuted for the Marlin kernel layout.
        scale_factor: Optional power-of-2 rescaling factor. If None, the
            factor is computed automatically so that every non-zero scale
            satisfies ``scale * 2^7 >= 2`` (i.e., the MSB of the S0E5M3
            representation is always 1). When provided (e.g., for MoE
            layers where all experts must share the same factor), the
            given value is used directly. The caller is responsible for
            dividing ``global_scale`` by the returned ``scale_factor`` to
            preserve numerical correctness.

    Returns:
        A tuple of (processed_scales, scale_factor).
    """
    if not (marlin_scales >= 0).all():
        logger.warning_once(
            "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
            "negative scales. Accuracy will likely be degraded. This is "
            "because it changes the scales from FP8-S1E4M3 to a special "
            "FP8-S0E5M3 format to speedup the dequantization."
        )

    # convert to half first, we would convert to fp8 later
    marlin_scales = marlin_scales.to(torch.half)

    # fit the layout of fp8 dequantization
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1
    )

    # We assume that weight_scale (FP8-S1E4M3) is always greater
    # than or equal to 0. So we can convert
    # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
    # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
    # when weight_scale > 0. This allows us to have an exponent bias
    # closer to zero after dequantization.

    # Rescale weight_scale so that all non-zero values have MSB=1
    # after multiplying by 2^7 (i.e., weight_scale * 2^7 >= 2).
    # This is needed for models whose E4M3 scales were not normalized
    # to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
    # The caller must compensate by dividing global_scale by scale_factor.
    if scale_factor is None:
        scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
    if scale_factor > 1.0:
        marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)

    marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
    marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
    marlin_scales = marlin_scales[:, 1::2].contiguous()

    return marlin_scales, scale_factor