Skip to content

vllm.model_executor.layers.quantization.mxfp8

Online MXFP8 (microscaling FP8, block-32) quantization config and methods.

Mxfp8Config

Bases: Fp8Config

Config class for online MXFP8 MoE quantization.

Source code in vllm/model_executor/layers/quantization/mxfp8.py
class Mxfp8Config(Fp8Config):
    """Config class for online MXFP8 MoE quantization.

    Reuses the fp8 config machinery but pins the settings that online
    MXFP8 requires: non-serialized checkpoints, dynamic activation
    scales, and no fp8-style weight block size (MXFP8 uses its own
    block-32 microscaling).
    """

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
    ) -> None:
        # Static activation scales are not supported for this scheme.
        if activation_scheme != "dynamic":
            raise ValueError("mxfp8 only supports dynamic activation scheme.")
        super().__init__(
            is_checkpoint_fp8_serialized=False,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=None,
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "mxfp8"

    @classmethod
    def get_min_capability(cls) -> int:
        # Blackwell-class (SM100) hardware or newer.
        return 100

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
        """Build a config from a hf-style quantization dict."""
        scheme = cls.get_from_keys_or(config, ["activation_scheme"], "dynamic")
        skipped = cls.get_from_keys_or(config, ["ignored_layers"], None)
        # Some checkpoints spell the skip list as "modules_to_not_convert".
        if not skipped:
            skipped = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        return cls(activation_scheme=scheme, ignored_layers=skipped)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Select the quantize method for `layer`, honoring the skip list."""

        def _is_skipped() -> bool:
            return is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
                skip_with_substr=True,
            )

        if isinstance(layer, LinearBase):
            if _is_skipped():
                return UnquantizedLinearMethod()
            return Mxfp8OnlineLinearMethod(self)
        if isinstance(layer, FusedMoE):
            if _is_skipped():
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return Mxfp8OnlineMoEMethod(self, layer)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

Mxfp8OnlineLinearMethod

Bases: Fp8OnlineLinearMethod

Online MXFP8 linear method. Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling FP8 with block-32 scales) during weight loading.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quant_config` | `Mxfp8Config` | The MXFP8 quantization config. | *required* |
Source code in vllm/model_executor/layers/quantization/mxfp8.py
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
    """Online MXFP8 linear method.

    Takes bf16/fp16 checkpoints and converts the weights to MXFP8
    (microscaling FP8 with one scale per 32-element block) during
    weight loading.

    Args:
        quant_config: The MXFP8 quantization config.
    """

    uses_meta_device: bool = True

    def __init__(self, quant_config: "Mxfp8Config"):
        self.quant_config = quant_config
        self.out_dtype = torch.get_default_dtype()
        self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
        logger.info_once(
            "Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
        )

    @staticmethod
    def _select_backend() -> Mxfp8LinearBackend:
        """Use FlashInfer CUTLASS when its mxfp8 GEMM is importable,
        otherwise fall back to emulation."""
        try:
            from vllm.utils import flashinfer as fi

            # Probe the attribute; any failure means the kernel is absent.
            _ = fi.mm_mxfp8
        except Exception:
            logger.warning(
                "FlashInfer mm_mxfp8 not available, "
                "falling back to MXFP8 emulation backend."
            )
            return Mxfp8LinearBackend.EMULATION
        return Mxfp8LinearBackend.FLASHINFER_CUTLASS

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate block divisibility, then defer weight creation to the
        fp8 base class."""
        # The reduction dim must split evenly into 32-wide scaling blocks.
        if input_size_per_partition % MXFP8_BLOCK_SIZE:
            raise ValueError(
                f"MXFP8 requires input_size_per_partition "
                f"({input_size_per_partition}) to be divisible by "
                f"{MXFP8_BLOCK_SIZE}."
            )
        super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded high-precision weight to MXFP8 in place.

        Safe to call repeatedly; subsequent calls are no-ops.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # A weight still on the meta device must be materialized (and
        # dummy-initialized) before it can be quantized.
        if layer.weight.device == torch.device("meta"):
            materialized = ModelWeightParameter(
                data=torch.empty_like(layer.weight, device=layer._load_device),
                input_dim=1,
                output_dim=0,
                weight_loader=layer.weight.weight_loader,
            )
            _copy_missing_attrs(layer.weight, materialized)
            layer.register_parameter("weight", materialized)
            initialize_single_dummy_weight(layer.weight)

        quantized, scale = mxfp8_e4m3_quantize(layer.weight.contiguous())

        # FlashInfer's CUTLASS kernel expects scales in a swizzled layout.
        if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
            n_dim, k_dim = layer.weight.shape[0], layer.weight.shape[1]
            scale = swizzle_mxfp8_scale(scale, n_dim, k_dim)

        # Dynamic activation scheme: no static input scale.
        layer.input_scale = None
        replace_parameter(layer, "weight", quantized.data)
        replace_parameter(layer, "weight_scale", scale.data)

        layer._already_called_process_weights_after_loading = True

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the MXFP8 GEMM on `x` with the layer's quantized weight."""
        return self.mxfp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            bias=bias,
        )

Mxfp8OnlineMoEMethod

Bases: Fp8OnlineMoEMethod

MoE method for online MXFP8 (block) quantization.

Source code in vllm/model_executor/layers/quantization/mxfp8.py
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
    """MoE method for online MXFP8 (block) quantization.

    Loads bf16/fp16 expert weights and quantizes them to MXFP8
    (fp8 values + uint8 block-32 scales) after weight loading.
    """

    uses_meta_device: bool = True

    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
        FusedMoEMethodBase.__init__(self, layer.moe_config)
        self.quant_config = quant_config
        # Online quantization starts from a high-precision checkpoint and
        # only supports dynamic activation scales (enforced by Mxfp8Config).
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"

        # One scale per 32-element block along the reduction dimension.
        self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
        self.block_quant = True
        self.weight_scale_name = "weight_scale"

        self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create high-precision expert weights plus uint8 MXFP8 scale params.

        Raises:
            ValueError: if either reduction dimension is not a multiple of
                the MXFP8 block size.
        """
        if (
            hidden_size % MXFP8_BLOCK_SIZE != 0
            or intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0
        ):
            raise ValueError(
                "Online MXFP8 MoE requires hidden/intermediate sizes divisible "
                f"by {MXFP8_BLOCK_SIZE}."
            )

        super().create_weights(
            layer=layer,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size_per_partition=intermediate_size_per_partition,
            params_dtype=params_dtype,
            **extra_weight_attrs,
        )

        # uint8 scale tensors: one entry per 32-element block of the
        # corresponding GEMM's reduction dim.
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // MXFP8_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]

    def _quantize_mxfp8_moe_weight(
        self, weight: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales).

        Args:
            weight: batched expert weights; dim 0 is the expert index.

        Returns:
            Tuple of stacked (fp8 weights, uint8 block scales).
        """
        num_batches = weight.size(0)
        w_quant = []
        w_scales = []
        for i in range(num_batches):
            mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize(
                weight[i], is_sf_swizzled_layout=False
            )
            w_quant.append(mx_fp8_quant)
            w_scales.append(mx_fp8_scale)

        return torch.stack(w_quant), torch.stack(w_scales)

    @staticmethod
    def _materialize_dummy_weight(layer: Module, name: str) -> None:
        """Replace a meta-device weight parameter with a real, dummy-filled
        tensor on the layer's load device. No-op if already materialized."""
        param = getattr(layer, name)
        if param.device != torch.device("meta"):
            return
        materialized = torch.nn.Parameter(
            torch.empty_like(param, device=layer._load_device),
            requires_grad=False,
        )
        set_weight_attrs(materialized, {"weight_loader": param.weight_loader})
        _copy_missing_attrs(param, materialized)
        layer.register_parameter(name, materialized)
        initialize_single_dummy_weight(getattr(layer, name))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Materialize meta weights if needed, quantize them to MXFP8, and
        hand the results to the MoE kernel setup. Idempotent."""
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self._materialize_dummy_weight(layer, "w13_weight")
        self._materialize_dummy_weight(layer, "w2_weight")

        # Quantize directly; the helper returns freshly stacked tensors, so
        # pre-allocating empty fp8 buffers here (as the old code did) was a
        # dead store that wasted a full fp8-sized copy of each weight.
        w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
        w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)

        self._setup_kernel(
            layer,
            w13,
            w2,
            w13_scale,
            w2_scale,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        layer._already_called_process_weights_after_loading = True

_quantize_mxfp8_moe_weight

_quantize_mxfp8_moe_weight(
    weight: Tensor,
) -> tuple[Tensor, Tensor]

Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales).

Source code in vllm/model_executor/layers/quantization/mxfp8.py
def _quantize_mxfp8_moe_weight(
    self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a batch of bf16/fp16 expert weights to MXFP8.

    Each slice along dim 0 is quantized independently; results are
    stacked back into (fp8 weights, uint8 scales).
    """
    quantized_experts: list[torch.Tensor] = []
    scale_factors: list[torch.Tensor] = []
    # Iterating the tensor yields one expert's weight matrix per step.
    for expert_weight in weight:
        fp8_values, fp8_scales = mxfp8_e4m3_quantize(
            expert_weight, is_sf_swizzled_layout=False
        )
        quantized_experts.append(fp8_values)
        scale_factors.append(fp8_scales)

    return torch.stack(quantized_experts), torch.stack(scale_factors)