vllm.model_executor.models.radio

RadioInternVisionModel

Bases: Module

Source code in vllm/model_executor/models/radio.py
class RadioInternVisionModel(nn.Module):
    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig = None,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(
            to_2tuple(config.patch_size), config.image_size
        )
        max_img_size = int(
            round(config.cpe_max_size / config.patch_size) * config.patch_size
        )
        self.temporal_patch_size = config.video_temporal_patch_size
        unique_teachers = set(t["name"] for t in config.teachers)
        self.patch_generator = ViTPatchGenerator(
            config.patch_size,
            config.hidden_size,
            input_dims=self.img_size,
            max_input_dims=max_img_size,
            cls_token=True,
            num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
            register_multiple=config.register_multiple,
            temporal_patch_size=self.temporal_patch_size,
            separate_video_embedder=config.separate_video_embedder,
        )

        self.encoder = RadioVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def _init_img_size(self, patch_size, img_size: int | tuple[int, int]):
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def get_input_embeddings(self):
        return self.embeddings

    def inter_image_mask_metadata(
        self, imgs_sizes: list[tuple[int, int]], device: torch.device
    ) -> MaskMetadata:
        """Build mask metadata from image pixel sizes. Adds num_skip to each
        sequence length (cls/register tokens) to match patch generator output."""
        patch_size = self.patch_generator.patch_size
        num_skip = self.patch_generator.num_skip

        seq_lens = calc_seq_lens(imgs_sizes, patch_size)
        adjusted = [s + num_skip for s in seq_lens]
        return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)

    def _inter_image_mask_metadata_from_seq_lens(
        self, seq_lens: list[int], device: torch.device
    ) -> MaskMetadata:
        """Build mask metadata from actual sequence lengths (already including
        cls/register tokens, i.e. patch_count + num_skip per item).
        Use inter_image_mask_metadata() when you only have imgs_sizes."""
        assert len(seq_lens) > 0
        cu_seqlens = torch.tensor(
            list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
        )
        # Keep max_seqlen on CPU to avoid .item() sync
        # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
        max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
        return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    def forward(
        self,
        x: torch.Tensor,
        imgs_sizes: list[tuple[int, int]] | None = None,
        num_frames: int | None = None,
    ) -> torch.FloatTensor:
        T = self.temporal_patch_size

        # Build packed-sequence metadata for MMEncoderAttention when needed.
        mask_meta = None
        packed_batch_size = None  # Original batch size before packing

        if num_frames is not None and T > 1:
            # Conv3d video: all tubelets have the same sequence length.
            # Pack [num_tubelets, seq_per_tubelet, hidden] → [1, total, hidden]
            hidden_states = self.patch_generator.forward_video(x)
            packed_batch_size, seq_per_tubelet, hidden_dim = hidden_states.shape
            hidden_states = hidden_states.reshape(1, -1, hidden_dim)
            mask_meta = self._inter_image_mask_metadata_from_seq_lens(
                [seq_per_tubelet] * packed_batch_size, device=hidden_states.device
            )
        else:
            # Images for any model, or video for non-conv3d model
            hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
            if imgs_sizes is not None and len(imgs_sizes) > 1:
                # Dynamic resolution w/ > 1 image, create attn mask
                mask_meta = self.inter_image_mask_metadata(
                    imgs_sizes, device=hidden_states.device
                )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)

        # Unpack back to original batch shape if we packed for video
        if packed_batch_size is not None:
            encoder_outputs = encoder_outputs.reshape(
                packed_batch_size, seq_per_tubelet, -1
            )

        return encoder_outputs
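
For reference, the pack/unpack step used in forward() for conv3d video can be reproduced in isolation. The following is a minimal sketch with made-up shapes, not part of the module: equal-length tubelet sequences are flattened into a single packed sequence for the encoder and restored to the original batch shape afterwards.

import torch

num_tubelets, seq_per_tubelet, hidden_dim = 4, 9, 16
hidden_states = torch.randn(num_tubelets, seq_per_tubelet, hidden_dim)

# Pack: [num_tubelets, seq, hidden] -> [1, num_tubelets * seq, hidden]
packed = hidden_states.reshape(1, -1, hidden_dim)
assert packed.shape == (1, num_tubelets * seq_per_tubelet, hidden_dim)

# ... the encoder would run on `packed` with per-tubelet mask metadata ...

# Unpack: restore the original batch shape after encoding
unpacked = packed.reshape(num_tubelets, seq_per_tubelet, -1)
assert torch.equal(unpacked, hidden_states)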

_inter_image_mask_metadata_from_seq_lens

_inter_image_mask_metadata_from_seq_lens(
    seq_lens: list[int], device: device
) -> MaskMetadata

Build mask metadata from actual sequence lengths (already including cls/register tokens, i.e. patch_count + num_skip per item). Use inter_image_mask_metadata() when you only have imgs_sizes.

Source code in vllm/model_executor/models/radio.py
def _inter_image_mask_metadata_from_seq_lens(
    self, seq_lens: list[int], device: torch.device
) -> MaskMetadata:
    """Build mask metadata from actual sequence lengths (already including
    cls/register tokens, i.e. patch_count + num_skip per item).
    Use inter_image_mask_metadata() when you only have imgs_sizes."""
    assert len(seq_lens) > 0
    cu_seqlens = torch.tensor(
        list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
    )
    # Keep max_seqlen on CPU to avoid .item() sync
    # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
    max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
    return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
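
A small worked example of the metadata construction above (example lengths only): cu_seqlens is the prefix sum of the per-item lengths starting at 0, and max_seqlen is kept as a CPU tensor.

from itertools import accumulate

import torch

seq_lens = [205, 109, 265]  # e.g. patch_count + num_skip per item
cu_seqlens = torch.tensor(list(accumulate(seq_lens, initial=0)), dtype=torch.int32)
print(cu_seqlens)  # tensor([  0, 205, 314, 579], dtype=torch.int32)

# Kept on CPU so the attention wrapper never triggers a device sync
max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
print(max_seqlen)  # tensor(265, dtype=torch.int32)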

inter_image_mask_metadata

inter_image_mask_metadata(
    imgs_sizes: list[tuple[int, int]], device: device
) -> MaskMetadata

Build mask metadata from image pixel sizes. Adds num_skip to each sequence length (cls/register tokens) to match patch generator output.

Source code in vllm/model_executor/models/radio.py
def inter_image_mask_metadata(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> MaskMetadata:
    """Build mask metadata from image pixel sizes. Adds num_skip to each
    sequence length (cls/register tokens) to match patch generator output."""
    patch_size = self.patch_generator.patch_size
    num_skip = self.patch_generator.num_skip

    seq_lens = calc_seq_lens(imgs_sizes, patch_size)
    adjusted = [s + num_skip for s in seq_lens]
    return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)
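
A hedged sketch of the size-to-length conversion performed here, assuming each image of pixel size (H, W) contributes (H // patch_size) * (W // patch_size) patches (as in _init_img_size) and that num_skip counts the cls and register tokens; the values below are hypothetical.

patch_size = 16
num_skip = 9  # hypothetical: 1 cls token + 8 register tokens

imgs_sizes = [(224, 224), (160, 320)]
seq_lens = [(h // patch_size) * (w // patch_size) for h, w in imgs_sizes]
adjusted = [s + num_skip for s in seq_lens]
print(seq_lens)  # [196, 200]
print(adjusted)  # [205, 209]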

ViTPatchGenerator

Bases: Module

Source code in vllm/model_executor/models/radio.py
class ViTPatchGenerator(nn.Module):
    def __init__(
        self,
        #  config: PretrainedConfig,
        patch_size: int,
        embed_dim: int,
        input_dims: input_dim_t,
        abs_pos: bool = True,
        normalize_patches: bool = False,
        cls_token: bool = False,
        max_input_dims: input_dim_t | None = None,
        pos_dropout: float = 0.0,
        return_pos_enc: bool = False,
        num_cls_tokens: int = 1,
        register_multiple: int | None = None,
        num_registers: int | None = None,
        patch_bias: bool = False,
        temporal_patch_size: int = 1,
        separate_video_embedder: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
        )

        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim
        self.temporal_patch_size = temporal_patch_size

        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(
            patch_size, embed_dim, bias=patch_bias, **factory
        )

        if temporal_patch_size > 1:
            if not separate_video_embedder:
                raise NotImplementedError(
                    "Only separate_video_embedder=True is supported for"
                    " temporal compression (temporal_patch_size > 1)"
                )
            self.video_embedder = ViTPatchLinear(
                patch_size,
                embed_dim,
                bias=patch_bias,
                temporal_patch_size=temporal_patch_size,
                **factory,
            )
            self._video_embedder_loaded = False

        if abs_pos:
            scale = embed_dim**-0.5
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
            )

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = (
            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
        )

    def forward(
        self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
    ) -> torch.Tensor:
        if imgs_sizes is not None:
            patches = self.embedder(x)
            patches, pos_enc = self.apply_pos_enc_dynamic(
                patches, imgs_sizes=imgs_sizes
            )
            patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
        else:
            patches = self.embed_patches(x)
            patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
            patches = self.cls_token(patches)
        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    def forward_video(self, x: torch.Tensor) -> torch.Tensor:
        """Process video frames with temporal compression.

        Groups T consecutive frames into tubelets before embedding.

        Args:
            x: [num_frames, 3, H, W] tensor of video frames

        Returns:
            Embedded patches with temporal compression applied.
        """
        if not self._video_embedder_loaded:
            raise ValueError(
                "Temporal compression (video_temporal_patch_size > 1) requires "
                "video_embedder weights, but they were never loaded. "
                "Ensure the checkpoint was trained with temporal compression."
            )
        T = self.temporal_patch_size
        input_size = x.shape[2:]

        patches = self.im_to_patches(x)  # [N, num_patches, 3*P*P]
        num_frames, num_spatial, feat_dim = patches.shape

        # Pad to a multiple of T by repeating the last frame so that
        # all tubelets have exactly T frames.
        num_pad_frames = (-num_frames) % T
        if num_pad_frames > 0:
            last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
            patches = torch.cat([patches, last_frame_dup], dim=0)

        # Group T frames per tubelet: for each spatial position, concatenate
        #   features across T consecutive frames; order follows Megatron training
        num_frames_padded = patches.shape[0]
        num_tubelets = num_frames_padded // T
        patches = rearrange(
            patches,
            "(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
            tubelets=num_tubelets,
            frames=T,
            spatial=num_spatial,
            feat=feat_dim,
        )

        patches = self.video_embedder(patches)

        patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)

        patches = self.cls_token(patches)

        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    def apply_pos_enc_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if not self.abs_pos:
            return patches, None

        current_length = 0
        pos_enc_list = []

        for size in imgs_sizes:
            seq_length = calc_seq_len(size, self.patch_size)

            img_patches = patches[:, current_length : current_length + seq_length, :]
            pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
            img_patches_with_pos = img_patches + pos_enc

            patches = torch.cat(
                [
                    patches[:, :current_length, :],
                    img_patches_with_pos,
                    patches[:, current_length + seq_length :, :],
                ],
                dim=1,
            )
            pos_enc_list.append(pos_enc)
            current_length += seq_length

        full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
        return patches, full_pos_enc

    def cls_token_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        if not self.cls_token.enabled:
            return patches

        out = []
        current_length = 0

        for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
            class_token = self.cls_token.token.unsqueeze(0).expand(
                patches.shape[0], -1, -1
            )
            out.append(class_token)
            out.append(patches[:, current_length : current_length + seq_len, :])
            current_length += seq_len

        return torch.cat(out, dim=1)

    @property
    def apply_cls_token(self):
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        return self.num_cls_tokens + self.num_registers

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size**2 == src_embed.shape[1], (
                "Unable to interpolate non-square embedding"
            )

            src_embed = rearrange(
                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
            )
            src_embed = F.interpolate(
                src_embed,
                size=(self.num_rows, self.num_cols),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
        targ_embed.data.copy_(src_embed)

    def _load_projection(
        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
    ):
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
                "Unable to interpolate non-square patch size"
            )

            src_proj_weight = rearrange(
                src_proj_weight,
                "b (c h w) -> b c h w",
                c=3,
                h=src_patch_size,
                w=src_patch_size,
            )
            src_proj_weight = F.interpolate(
                src_proj_weight,
                size=(self.patch_size, self.patch_size),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(
        self,
        patches: torch.Tensor,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if not self.abs_pos:
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            keeps = (
                torch.rand(
                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
                )
                > self.pos_dropout
            )
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(
        self,
        batch_size: int,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(
            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
        )
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
            0, 3, 1, 2
        )

        def window_select(pos_embed):
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., : input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, : input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            max_dim = max(input_dims)
            pos_embed = F.interpolate(
                pos_embed.float(),
                size=(max_dim, max_dim),
                align_corners=False,
                mode="bilinear",
            ).to(pos_embed.dtype)

            pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(
                pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear"
            ).to(pos_embed.dtype)

        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed
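
The grid resizing at the heart of _get_pos_embeddings can be shown on its own. This is an illustrative sketch with made-up sizes: the learned [1, N, C] positional table is reshaped to its 2D grid, resampled bilinearly to the requested grid, and flattened back.

import torch
import torch.nn.functional as F

num_rows = num_cols = 32
embed_dim = 8
pos_embed = torch.randn(1, num_rows * num_cols, embed_dim)

# [1, N, C] -> [1, C, H, W] so F.interpolate can resample the grid
grid = pos_embed.reshape(1, num_rows, num_cols, embed_dim).permute(0, 3, 1, 2)

target = (14, 20)  # e.g. a 224x320 image with patch_size 16
resized = F.interpolate(grid.float(), size=target, align_corners=False, mode="bilinear")

# Back to [1, H*W, C] for addition to the patch embeddings
resized = resized.flatten(2).permute(0, 2, 1)
assert resized.shape == (1, target[0] * target[1], embed_dim)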

forward_video

forward_video(x: Tensor) -> Tensor

Process video frames with temporal compression.

Groups T consecutive frames into tubelets before embedding.

Parameters:

    x (Tensor): [num_frames, 3, H, W] tensor of video frames. Required.

Returns:

    Tensor: Embedded patches with temporal compression applied.

Source code in vllm/model_executor/models/radio.py
def forward_video(self, x: torch.Tensor) -> torch.Tensor:
    """Process video frames with temporal compression.

    Groups T consecutive frames into tubelets before embedding.

    Args:
        x: [num_frames, 3, H, W] tensor of video frames

    Returns:
        Embedded patches with temporal compression applied.
    """
    if not self._video_embedder_loaded:
        raise ValueError(
            "Temporal compression (video_temporal_patch_size > 1) requires "
            "video_embedder weights, but they were never loaded. "
            "Ensure the checkpoint was trained with temporal compression."
        )
    T = self.temporal_patch_size
    input_size = x.shape[2:]

    patches = self.im_to_patches(x)  # [N, num_patches, 3*P*P]
    num_frames, num_spatial, feat_dim = patches.shape

    # Pad to a multiple of T by repeating the last frame so that
    # all tubelets have exactly T frames.
    num_pad_frames = (-num_frames) % T
    if num_pad_frames > 0:
        last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
        patches = torch.cat([patches, last_frame_dup], dim=0)

    # Group T frames per tubelet: for each spatial position, concatenate
    #   features across T consecutive frames; order follows Megatron training
    num_frames_padded = patches.shape[0]
    num_tubelets = num_frames_padded // T
    patches = rearrange(
        patches,
        "(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
        tubelets=num_tubelets,
        frames=T,
        spatial=num_spatial,
        feat=feat_dim,
    )

    patches = self.video_embedder(patches)

    patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)

    patches = self.cls_token(patches)

    patches = self.patch_normalizer(patches)
    if self.return_pos_enc:
        return patches, pos_enc
    return patches
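
The frame-grouping rearrange used above can be checked on tiny tensors. This is a minimal sketch (shapes made up): T consecutive frames are concatenated along the feature dimension at every spatial position before the video embedder runs.

import torch
from einops import rearrange

T = 2  # temporal_patch_size
num_frames, num_spatial, feat_dim = 4, 3, 5
patches = torch.randn(num_frames, num_spatial, feat_dim)

tubelets = rearrange(
    patches,
    "(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
    frames=T,
)
assert tubelets.shape == (num_frames // T, num_spatial, T * feat_dim)

# Tubelet 0, spatial position 0: frame 0's features followed by frame 1's
assert torch.equal(tubelets[0, 0], torch.cat([patches[0, 0], patches[1, 0]]))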