Skip to content

vllm.model_executor.models.transformers.base

Transformers modeling backend base class.

Base

Bases: Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3

Source code in vllm/model_executor/models/transformers/base.py
(The listing below covers lines 102–646 of the source file.)
class Base(
    nn.Module,
    VllmModel,
    SupportsQuant,
    SupportsLoRA,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """Base class for running Hugging Face Transformers models inside vLLM.

    The Transformers model is first instantiated on the `meta` device. Then:

    - layers not owned by this pipeline-parallel rank are replaced with
      `PPMissingLayer`
    - `nn.Linear`, `nn.Conv2d`/`nn.Conv3d` and `*RMSNorm` modules are swapped
      for vLLM implementations
    - the input embedding is replaced with `VocabParallelEmbedding`
    - any parameters still on `meta` are materialised on the target device

    The config's attention implementation is set to `"vllm"`, and the
    `Attention` instances from `create_attention_instances` are passed to the
    model on every forward call.
    """

    embedding_modules = ["embed_tokens"]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        """Build the Transformers model for this rank and adapt it to vLLM.

        Args:
            vllm_config: vLLM configuration. The HF, cache, device, model,
                parallel and quantization configs are read from it.
            prefix: Module prefix. NOTE(review): not used in this method.

        Raises:
            NotImplementedError: If the quantization method is MXFP4.
            ValueError: Propagated from `pipeline_parallel`/`recursive_replace`
                when PP/TP is requested but the model does not support it.
        """
        super().__init__()
        logger.info("Using Transformers modeling backend.")

        self.config = vllm_config.model_config.hf_config
        self.text_config = self.config.get_text_config()
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config

        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Attrs for weight loading (see self.load_weights)
        self.skip_prefixes: list[str] = []
        """Skip loading weights whose qualname starts with these prefixes."""
        self.skip_substrs: list[str] = []
        """Skip loading weights whose qualname contains these substrings."""
        self.ignore_unexpected_prefixes: list[str] = []
        """Ignore unexpected weights whose qualname starts with these prefixes."""
        self.ignore_unexpected_suffixes: list[str] = []
        """Ignore unexpected weights whose qualname ends with these suffixes."""

        # Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
        self._target_class: type[nn.Module] = nn.Module
        """Target class for Eagle3 aux hidden state recording."""
        self._layer_names: dict[int, str] = {}
        """Mapping from layer index to layer name for Eagle3."""
        self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
        """Kwargs to pass to model forward for Eagle3 aux hidden states."""

        if self.quant_config:
            quant_method_name = self.quant_config.get_name()
            # Check for unsupported quantization methods.
            if quant_method_name == "mxfp4":
                raise NotImplementedError(
                    "Transformers modeling backend does "
                    "not support MXFP4 quantization yet."
                )
            # Skip loading extra bias for GPTQ models.
            if "gptq" in quant_method_name:
                self.ignore_unexpected_suffixes.append(".bias")

        # Patch config and init on "meta" to delay allocating GPU tensors
        self._patch_config()
        with init_on_device_without_buffers("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(
                self.config,
                dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Create weight name to module qualname mapper
        self._create_hf_to_vllm_mapper()
        # Remove layers not on this pipeline parallel rank
        self.pipeline_parallel()
        # Substitute remaining layers with vLLM's layers as needed
        self.recursive_replace()
        # Create attention instances for KV cache allocation
        self.attention_instances = self.create_attention_instances()

        # Input embeddings
        self.embed_scale = None
        input_embeddings = self.model.get_input_embeddings()
        if not isinstance(input_embeddings, PPMissingLayer):
            # Some models scale embeddings inside the input embedding layer
            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
            names = ("embedding_size", "hidden_size")
            embedding_dim = getattr_iter(self.text_config, names, None)
            assert embedding_dim is not None
            self.model.set_input_embeddings(
                VocabParallelEmbedding(
                    self.text_config.vocab_size,
                    embedding_dim=embedding_dim,
                    org_num_embeddings=self.text_config.vocab_size,
                    quant_config=self.quant_config,
                )
            )

        # Initialize any parameters that have not had their modules replaced
        self.init_parameters(self.model)

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.text_config.hidden_size
        )

    def _patch_config(self):
        """
        Patch the config to ensure that the model is created correctly:

        - Sets the attention implementation to "vllm" so the attention instances from
        `create_attention_instances` are used
        - Sets the dtype to the default torch dtype set by vLLM because Transformers
        uses the config dtype when creating the model
        - Propagates this dtype to any sub-configs because Transformers model
        implementations do not support/use different dtypes in sub-models.
        Sub-configs that are unset (`None`) are skipped.
        """
        self.text_config._attn_implementation = "vllm"
        self.config.dtype = torch.get_default_dtype()
        dtype = self.config.dtype
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        for sub_config_name in getattr(self.config, "sub_configs", {}):
            sub_config = getattr(self.config, sub_config_name, None)
            # Optional sub-configs may be unset; there is nothing to patch then
            if sub_config is None:
                continue
            if sub_config.dtype != dtype:
                sub_config.dtype = dtype

    def _create_hf_to_vllm_mapper(self):
        """
        Create a WeightsMapper to map checkpoint weight names to module qualnames.

        This handles:

        - Transformers weight renaming:
            - from `WeightRenaming` in Transformers v5
            - from `_checkpoint_conversion_mapping` in Transformers v4
        - Keys in `_keys_to_ignore_on_load_unexpected` (mapped to `None`)
        - Checkpoints saved with a base model prefix that is not `model`
        - Checkpoints saved with no base model prefix
        - Weights saved under `model.` that are not direct children of it
        - `lm_head` weights saved nested inside the base model
        - Any quantization config specific mappings

        The mapper is stored in `self.hf_to_vllm_mapper`.
        """
        self.hf_to_vllm_mapper = WeightsMapper()
        orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex

        if Version(transformers.__version__) >= Version("5.0.0"):
            from transformers.conversion_mapping import (
                WeightRenaming,
                get_model_conversion_mapping,
            )

            for mapping in get_model_conversion_mapping(self.model):
                # Handle weights which have been renamed in Transformers
                if isinstance(mapping, WeightRenaming):
                    # Recompile using regex (Transformers used re)
                    compiled_sources = re.compile(
                        mapping.compiled_sources.pattern, mapping.compiled_sources.flags
                    )
                    target_pattern = mapping.target_patterns[0]
                    orig_to_new_regex[compiled_sources] = target_pattern
                # TODO: Handle WeightConverter to enable layer merging
        else:
            # Replace legacy suffixes used for norms
            # TODO(hmellor): Remove this when Transformers v4 support is dropped
            orig_to_new_regex.update(
                {
                    re.compile(r"\.gamma$"): ".weight",
                    re.compile(r"\.beta$"): ".bias",
                }
            )

        # Handle weights which have been renamed in Transformers
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
        for source, target in ccm.items():
            orig_to_new_regex[re.compile(source)] = target

        # Handle unexpected weights which should be ignored
        if self.model._keys_to_ignore_on_load_unexpected is not None:
            for key in self.model._keys_to_ignore_on_load_unexpected:
                orig_to_new_regex[re.compile(key)] = None

        # Standardise base model prefix
        bmp = self.model.base_model_prefix
        expected_bmp = r"model.\1"
        # Handle checkpoints saved with different base model prefix
        if bmp and bmp != "model":
            different_bmp_pattern = re.compile(rf"^{re.escape(bmp)}\.(.+)")
            orig_to_new_regex[different_bmp_pattern] = expected_bmp
        # Handle direct children of self.model which were saved without the model prefix
        direct_children = chain(
            self.model.named_children(),
            self.model.named_parameters(recurse=False),
            self.model.named_buffers(recurse=False),
        )
        child_names = [re.escape(name) for name, _ in direct_children]
        # With no children an empty alternation would match every key, so skip
        if child_names:
            alternatives = "|".join(child_names)
            # `\b` makes each name match as a whole, so a child such as `norm`
            # does not also claim unrelated weights such as `norm_out.weight`
            model_children = rf"(?:{alternatives})\b"
            missing_bmp_pattern = re.compile(rf"^(?!model\.)({model_children}.*)")
            orig_to_new_regex[missing_bmp_pattern] = expected_bmp
            # Handle weights saved as direct children of self.model which no longer are
            unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
            orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
        # Handle lm_head which was saved inside the base model
        nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
        orig_to_new_regex[nested_lm_head_pattern] = r"\2"

        # Apply mapping to quantization config if needed
        self._maybe_apply_model_mapping()

    def _get_tie_word_embeddings(self):
        """
        Check if the model has tied word embeddings.
        """
        # Transformers v4 and v5 will store this in different places
        tie_word_embeddings_v4 = getattr(self.text_config, "tie_word_embeddings", False)
        tie_word_embeddings_v5 = getattr(self.config, "tie_word_embeddings", False)
        return tie_word_embeddings_v4 or tie_word_embeddings_v5

    def pipeline_parallel(self):
        """
        Apply the model's pipeline parallelization plan.

        This is a no-op when the pipeline-parallel world size is 1. Otherwise,
        entries of `self.model._pp_plan` are replaced with `PPMissingLayer`:

        - entries before the single `ModuleList` are kept on the first rank, and
          also on the last rank when word embeddings are tied
        - only the layers in `[start_layer, end_layer)` of the `ModuleList` are kept
        - entries after the `ModuleList` are kept on the last rank only

        Raises:
            ValueError: If the model has no PP plan, or its plan contains zero or
                more than one `ModuleList`.
        """
        if self.pp_group.world_size <= 1:
            return

        if not self.model.supports_pp_plan:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support pipeline parallel. {tip}"
            )

        def attrsetter(attr: str) -> Callable[[object, object], None]:
            """Set a possibly nested attribute, like the inverse of attrgetter."""
            parent, _, name = attr.rpartition(".")

            def setter(obj: object, value: object):
                attr_parent = attrgetter(parent)(obj) if parent else obj
                setattr(attr_parent, name, value)

            return setter

        module_lists = []
        module_list_idx = None
        pp_plan = list(self.model._pp_plan.keys())
        for i, name in enumerate(pp_plan):
            # attrgetter in case the module is nested (e.g. "text_model.layers")
            if isinstance(attrgetter(name)(self.model), nn.ModuleList):
                module_lists.append(name)
                module_list_idx = i

        if len(module_lists) > 1:
            raise ValueError(
                "Pipeline parallel of models with multiple `ModuleList`s "
                "in the base model are not supported yet!"
            )
        if module_list_idx is None:
            raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")

        # Layers before module list
        for name in pp_plan[:module_list_idx]:
            # NOTE(review): presumably kept on the last rank with tied embeddings
            # so the output projection can share the embedding weights
            if self.pp_group.is_first_rank or (
                self._get_tie_word_embeddings() and self.pp_group.is_last_rank
            ):
                continue
            # attrsetter in case the module is nested (e.g. "text_model.embed_tokens")
            attrsetter(name)(self.model, PPMissingLayer())

        # Module list
        start_layer, end_layer = get_pp_indices(
            self.text_config.num_hidden_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        layers_name = pp_plan[module_list_idx]
        # attrgetter in case the module is nested (e.g. "text_model.layers")
        layers = attrgetter(layers_name)(self.model)
        for i in range(len(layers)):
            if start_layer <= i < end_layer:
                continue
            layers[i] = PPMissingLayer()

        # Layers after module list
        for name in pp_plan[module_list_idx + 1 :]:
            # Modules that should be on last rank
            if not self.pp_group.is_last_rank:
                # attrsetter in case the module is nested (e.g. "text_model.norm")
                attrsetter(name)(self.model, PPMissingLayer())

    def recursive_replace(self):
        """Recursively replace modules in the model as needed.

        Currently, this replaces:

        - `nn.Linear` with vLLM's tensor parallel linear classes (style taken from
          the model's TP plan, defaulting to `"replicate"`)
        - `nn.Conv2d`/`nn.Conv3d` via `replace_conv_class`
        - `*RMSNorm` with vLLM's `RMSNorm`

        While walking the decoder-layer `ModuleList`, it also records the Eagle3
        target class and layer names, and registers the prefixes of any MTP
        layers in `ignore_unexpected_prefixes`.

        Raises:
            ValueError: If tensor parallel is requested but the model has no TP plan.
        """
        tp_plan = self.model.tp_plan

        if not tp_plan and self.tp_group.world_size > 1:
            tip = get_feature_request_tip(
                self.model_config.model, self.model_config.trust_remote_code
            )
            raise ValueError(
                f"{type(self.model)} does not support tensor parallel. {tip}"
            )

        # Prefix the patterns because we always start from `self.model`
        tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}

        def _recursive_replace(module: nn.Module, prefix: str):
            for child_name, child_module in module.named_children():
                new_module = child_module
                qual_name = maybe_prefix(prefix, child_name)
                # A ModuleList with one entry per hidden layer is taken to be the
                # decoder layers
                if (
                    isinstance(module, nn.ModuleList)
                    and len(module) == self.text_config.num_hidden_layers
                ):
                    # Populate Eagle3 attrs
                    self._target_class = type(child_module)
                    layer_name = qual_name.removeprefix("model.")
                    self._layer_names[int(child_name)] = layer_name
                    # MTP weights should not be loaded into the base model
                    num_hidden_layers = self.text_config.num_hidden_layers
                    names = (
                        "n_predict",  # Override from SpeculativeConfig
                        "num_nextn_predict_layers",  # Most models
                        "mtp_num_hidden_layers",  # Qwen 3.5
                    )
                    n_predict = getattr_iter(self.text_config, names, 0)
                    for i in range(num_hidden_layers, num_hidden_layers + n_predict):
                        mtp_prefix = f"{prefix}.{i}."
                        if mtp_prefix not in self.ignore_unexpected_prefixes:
                            self.ignore_unexpected_prefixes.append(mtp_prefix)
                # Replace modules as needed
                if isinstance(child_module, nn.Linear):
                    generator = (p for p in tp_plan if re.match(p, qual_name))
                    pattern = next(generator, None)
                    # Some weight loaders expect all linear layers to inherit
                    # LinearBase, so we set a default style which causes any
                    # unspecified layers to be replaced with ReplicatedLinear
                    style = tp_plan.get(pattern, "replicate")
                    new_module = replace_linear_class(
                        child_module, style, self.quant_config, prefix=qual_name
                    )
                elif isinstance(child_module, (nn.Conv2d, nn.Conv3d)):
                    new_module = replace_conv_class(child_module)
                elif child_module.__class__.__name__.endswith("RMSNorm"):
                    new_module = replace_rms_norm_class(
                        child_module, self.text_config.hidden_size
                    )
                else:
                    _recursive_replace(child_module, prefix=qual_name)

                if new_module is not child_module:
                    setattr(module, child_name, new_module)
                    log_replacement(qual_name, child_module, new_module)

        _recursive_replace(self.model, prefix="model")

    def create_attention_instances(self) -> dict[int, Attention]:
        """
        Create `Attention` instances to inform KV cache allocation.

        One instance is created per decoder layer owned by this pipeline-parallel
        rank, keyed by layer index. If any module has `is_causal=False` and the
        config is text-only, `EncoderOnlyAttention` is used for every layer.
        Otherwise decoder `Attention` is used. Layers marked `"sliding_attention"`
        in the text config's `layer_types` get its `sliding_window`.

        Raises:
            ImportError: For encoder-only models on Transformers < 5.0.0.
        """
        text_config = self.text_config

        num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None)

        # In encoder models, the attention layers will have `is_causal=False`
        is_encoder = lambda module: not getattr(module, "is_causal", True)
        has_encoder = lambda model: any(is_encoder(m) for m in model.modules())
        is_multimodal = lambda config: config != config.get_text_config()
        # vLLM does not support encoder-decoder models, so if any encoder layer is
        # found in a text only model, we assume the whole model is an encoder model
        if has_encoder(self.model) and not is_multimodal(self.config):
            self.check_version("5.0.0", "encoder models support")
            attn_type = AttentionType.ENCODER_ONLY
        else:
            attn_type = AttentionType.DECODER
        attn_cls = (
            EncoderOnlyAttention if attn_type == AttentionType.ENCODER_ONLY else Attention
        )

        # Per-layer attention types describe the text decoder layers, so read them
        # from the text config (the same config the layer count comes from).
        # Some configs define `layer_types` but leave it as `None`.
        layer_types = getattr(text_config, "layer_types", None)

        pp_rank = self.pp_group.rank_in_group
        pp_size = self.pp_group.world_size
        start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size)

        attention_instances = {}
        for i in range(start, end):
            # Handle interleaved sliding window attention
            per_layer_sliding_window = None
            if layer_types is not None and layer_types[i] == "sliding_attention":
                per_layer_sliding_window = text_config.sliding_window

            attention_instances[i] = attn_cls(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=head_size**-0.5,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                logits_soft_cap=logits_soft_cap,
                per_layer_sliding_window=per_layer_sliding_window,
                prefix=f"{i}.attn",
                attn_type=attn_type,
            )
        return attention_instances

    def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
        """
        Materialise every parameter of `module` (recursively) that is still on
        the `meta` device.

        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: "PreTrainedModel" = AutoModel.from_config(...)
        ```

        i.e. it was not replaced by a vLLM module. New parameters are allocated
        with `torch.empty_like` (uninitialised values) on the configured device.
        NOTE(review): presumably they are filled later by `load_weights`.

        Args:
            module: Root module to scan.
            dtype: Dtype for the new parameters; defaults to the model dtype.
        """

        def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
            for name, param in module.named_parameters(recurse=False):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(
                            param.data,
                            dtype=dtype or self.model_config.dtype,
                            device=self.device_config.device,
                        )
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_parameters(child, dtype)

        _init_parameters(module, dtype)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings, applying `embed_scale` when the model has one.

        `VocabParallelEmbedding` does not apply the scale that some HF embedding
        layers apply internally, so it is applied here (in place).
        """
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds *= self.embed_scale
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the Transformers model on a flattened batch of tokens.

        A leading batch dimension of size 1 is added to the inputs before calling
        the model and removed from its outputs.

        Args:
            input_ids: Token ids; ignored on non-first pipeline-parallel ranks.
            positions: Token positions. With M-RoPE they are indexed as
                `positions[:, None]`, otherwise as `positions[None, ...]`.
            intermediate_tensors: Hidden states from the previous pipeline
                stage; required on non-first ranks.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Forwarded to the Transformers model.

        Returns:
            `IntermediateTensors` on non-last ranks. Otherwise the hidden states,
            or `(hidden_states, aux_hidden_states)` when Eagle3 auxiliary hidden
            states were requested and returned.
        """
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        if input_ids is not None:
            input_ids = input_ids[None, ...]
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds[None, ...]

        # If the model scales embeddings inside the input embedding layer we must
        # ensure they are scaled here since VocabParallelEmbedding will not do it
        if (
            self.embed_scale is not None
            and input_ids is not None
            and inputs_embeds is None
        ):
            inputs_embeds = self.embed_input_ids(input_ids)
            input_ids = None

        if self.model_config.uses_mrope:
            position_ids = positions[:, None]
        else:
            position_ids = positions[None, ...]

        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            position_ids=position_ids,
            attention_instances=self.attention_instances,
            return_dict=False,
            **self._output_aux_hidden_states_kwargs,
            **kwargs,
        )
        # We must remove the batch dimension from these outputs
        hidden_states = outputs[0][0, ...]
        aux_hidden_states: list[torch.Tensor] = []
        if self._output_aux_hidden_states_kwargs:
            aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]

        if not self.pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint weights through `AutoWeightsLoader`.

        Names are translated with `self.hf_to_vllm_mapper`, and the skip/ignore
        lists collected during construction are applied.

        Returns:
            The `set[str]` returned by `AutoWeightsLoader.load_weights`
            (presumably the names of the loaded parameters).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            skip_substrs=self.skip_substrs,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    @staticmethod
    def check_version(min_version: str, feature: str):
        """Require the installed Transformers to be at least `min_version`.

        Raises:
            ImportError: If the installed version is older; the message names
                `feature`.
        """
        installed = Version(transformers.__version__)
        required = Version(min_version)
        if installed < required:
            raise ImportError(
                f"Transformers modeling backend requires transformers>={required} "
                f"for {feature}, but got {installed}"
            )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Make `forward` also return the hidden states entering `layers`.

        For each index `layer`, an `OutputRecorder` for layer `layer - 1` (whose
        output is the input to `layer`) is registered under the key
        `aux_hidden_state_{layer}`. `output_aux_hidden_state_{layer}=True` is then
        passed to the model on every forward call.

        Raises:
            ImportError: If Transformers is older than 5.2.0.
            KeyError: If `layer - 1` is not a recorded decoder layer index.
        """
        self.check_version("5.2.0", "Eagle3 support")
        from transformers.utils.output_capturing import (
            OutputRecorder,
            maybe_install_capturing_hooks,
        )

        # The default value in PreTrainedModel is None
        if self.model._can_record_outputs is None:
            self.model._can_record_outputs = {}

        target_class = self._target_class
        for layer in layers:
            # layer - 1 because we want the input to the layer
            layer_name = self._layer_names[layer - 1]
            layer_key = f"aux_hidden_state_{layer}"
            aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
            self.model._can_record_outputs[layer_key] = aux_hidden_state_i
            self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True

        # Ensure that the capture hooks are installed before dynamo traces the model
        maybe_install_capturing_hooks(self.model)

    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default Eagle3 layers: `(2, L // 2, L - 3)` for `L` layers."""
        num_layers = self.text_config.num_hidden_layers
        return (2, num_layers // 2, num_layers - 3)

_layer_names instance-attribute

_layer_names: dict[int, str] = {}

Mapping from layer index to layer name for Eagle3.

_output_aux_hidden_states_kwargs instance-attribute

_output_aux_hidden_states_kwargs: dict[str, bool] = {}

Kwargs to pass to model forward for Eagle3 aux hidden states.

_target_class instance-attribute

_target_class: type[Module] = Module

Target class for Eagle3 aux hidden state recording.

ignore_unexpected_prefixes instance-attribute

ignore_unexpected_prefixes: list[str] = []

Ignore unexpected weights whose qualname starts with these prefixes.

ignore_unexpected_suffixes instance-attribute

ignore_unexpected_suffixes: list[str] = []

Ignore unexpected weights whose qualname ends with these suffixes.

skip_prefixes instance-attribute

skip_prefixes: list[str] = []

Skip loading weights whose qualname starts with these prefixes.

skip_substrs instance-attribute

skip_substrs: list[str] = []

Skip loading weights whose qualname contains these substrings.

_create_hf_to_vllm_mapper

_create_hf_to_vllm_mapper()

Create a WeightsMapper to map checkpoint weight names to module qualnames.

This handles:

  • Transformers weight renaming:
    • from WeightRenaming in Transformers v5
    • from _checkpoint_conversion_mapping in Transformers v4
  • Checkpoints saved with a base model prefix that is not model
  • Checkpoints saved with no base model prefix
  • Any quantization config specific mappings
Source code in vllm/model_executor/models/transformers/base.py
def _create_hf_to_vllm_mapper(self):
    """
    Create a WeightsMapper to map checkpoint weight names to module qualnames.

    This handles:

    - Transformers weight renaming:
        - from `WeightRenaming` in Transformers v5
        - from `_checkpoint_conversion_mapping` in Transformers v4
    - Keys in `_keys_to_ignore_on_load_unexpected` (mapped to `None`)
    - Checkpoints saved with a base model prefix that is not `model`
    - Checkpoints saved with no base model prefix
    - Weights saved under `model.` that are not direct children of it
    - `lm_head` weights saved nested inside the base model
    - Any quantization config specific mappings

    The mapper is stored in `self.hf_to_vllm_mapper`.
    """
    self.hf_to_vllm_mapper = WeightsMapper()
    orig_to_new_regex = self.hf_to_vllm_mapper.orig_to_new_regex

    if Version(transformers.__version__) >= Version("5.0.0"):
        from transformers.conversion_mapping import (
            WeightRenaming,
            get_model_conversion_mapping,
        )

        for mapping in get_model_conversion_mapping(self.model):
            # Handle weights which have been renamed in Transformers
            if isinstance(mapping, WeightRenaming):
                # Recompile using regex (Transformers used re)
                compiled_sources = re.compile(
                    mapping.compiled_sources.pattern, mapping.compiled_sources.flags
                )
                target_pattern = mapping.target_patterns[0]
                orig_to_new_regex[compiled_sources] = target_pattern
            # TODO: Handle WeightConverter to enable layer merging
    else:
        # Replace legacy suffixes used for norms
        # TODO(hmellor): Remove this when Transformers v4 support is dropped
        orig_to_new_regex.update(
            {
                re.compile(r"\.gamma$"): ".weight",
                re.compile(r"\.beta$"): ".bias",
            }
        )

    # Handle weights which have been renamed in Transformers
    # TODO(hmellor): Remove this when Transformers v4 support is dropped
    ccm = getattr(self.model, "_checkpoint_conversion_mapping", {})
    for source, target in ccm.items():
        orig_to_new_regex[re.compile(source)] = target

    # Handle unexpected weights which should be ignored
    if self.model._keys_to_ignore_on_load_unexpected is not None:
        for key in self.model._keys_to_ignore_on_load_unexpected:
            orig_to_new_regex[re.compile(key)] = None

    # Standardise base model prefix
    bmp = self.model.base_model_prefix
    expected_bmp = r"model.\1"
    # Handle checkpoints saved with different base model prefix
    if bmp and bmp != "model":
        different_bmp_pattern = re.compile(rf"^{re.escape(bmp)}\.(.+)")
        orig_to_new_regex[different_bmp_pattern] = expected_bmp
    # Handle direct children of self.model which were saved without the model prefix
    direct_children = chain(
        self.model.named_children(),
        self.model.named_parameters(recurse=False),
        self.model.named_buffers(recurse=False),
    )
    child_names = [re.escape(name) for name, _ in direct_children]
    # With no children an empty alternation would match every key, so skip
    if child_names:
        alternatives = "|".join(child_names)
        # `\b` makes each name match as a whole, so a child such as `norm`
        # does not also claim unrelated weights such as `norm_out.weight`
        model_children = rf"(?:{alternatives})\b"
        missing_bmp_pattern = re.compile(rf"^(?!model\.)({model_children}.*)")
        orig_to_new_regex[missing_bmp_pattern] = expected_bmp
        # Handle weights saved as direct children of self.model which no longer are
        unexpected_bmp_pattern = re.compile(rf"^(model\.)((?!{model_children}).+)")
        orig_to_new_regex[unexpected_bmp_pattern] = r"\2"
    # Handle lm_head which was saved inside the base model
    nested_lm_head_pattern = re.compile(r"^model\.(.+\.)*(lm_head.+)")
    orig_to_new_regex[nested_lm_head_pattern] = r"\2"

    # Apply mapping to quantization config if needed
    self._maybe_apply_model_mapping()

_get_tie_word_embeddings

_get_tie_word_embeddings()

Check if the model has tied word embeddings.

Source code in vllm/model_executor/models/transformers/base.py
def _get_tie_word_embeddings(self):
    """
    Report whether the model ties its input embeddings to the LM head.

    Transformers v4 keeps `tie_word_embeddings` on the text config while v5
    keeps it on the top-level config, so both are consulted. A missing
    attribute counts as `False`; the first truthy value found is returned.
    """
    flag_name = "tie_word_embeddings"
    text_flag = getattr(self.text_config, flag_name, False)
    config_flag = getattr(self.config, flag_name, False)
    if text_flag:
        return text_flag
    return config_flag

_patch_config

_patch_config()

Patch the config to ensure that the model is created correctly:

  • Sets the attention implementation to "vllm" so the attention instances from create_attention_instances are used
  • Sets the dtype to the default torch dtype set by vLLM because Transformers uses the config dtype when creating the model
  • Propagates this dtype to any sub-configs because Transformers model implementations do not support/use different dtypes in sub-models
Source code in vllm/model_executor/models/transformers/base.py
def _patch_config(self):
    """
    Patch the config to ensure that the model is created correctly:

    - Sets the attention implementation to "vllm" so the attention instances from
    `create_attention_instances` are used
    - Sets the dtype to the default torch dtype set by vLLM because Transformers
    uses the config dtype when creating the model
    - Propagates this dtype to any sub-configs because Transformers model
    implementations do not support/use different dtypes in sub-models

    Sub-configs that are listed in `sub_configs` but are not set on the config
    (attribute missing or `None`) are skipped, since there is nothing to patch.
    """
    self.text_config._attn_implementation = "vllm"
    self.config.dtype = torch.get_default_dtype()
    # Read the value back once so sub-configs receive exactly what the config
    # stored (rather than re-reading it on every iteration)
    dtype = self.config.dtype
    # TODO(hmellor): Remove this when Transformers v4 support is dropped
    for sub_config_name in getattr(self.config, "sub_configs", {}):
        sub_config = getattr(self.config, sub_config_name, None)
        # Optional sub-configs may be unset; previously this raised
        # `AttributeError` on `None.dtype`
        if sub_config is None:
            continue
        if sub_config.dtype != dtype:
            sub_config.dtype = dtype

create_attention_instances

create_attention_instances() -> dict[int, Attention]

Create Attention instances to inform KV cache allocation.

Source code in vllm/model_executor/models/transformers/base.py
def create_attention_instances(self) -> dict[int, Attention]:
    """
    Create `Attention` instances to inform KV cache allocation.

    One instance is created for each decoder layer assigned to this
    pipeline-parallel rank (the range returned by `get_pp_indices`), keyed by
    layer index. Layers marked `"sliding_attention"` in `layer_types` are given
    a per-layer sliding window.

    Returns:
        Mapping from layer index to its attention instance
        (`EncoderOnlyAttention` when the model is treated as encoder-only,
        `Attention` otherwise).
    """
    text_config = self.text_config

    num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
    head_size = self.model_config.get_head_size()
    num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
    logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None)

    # In encoder models, the attention layers will have `is_causal=False`
    is_encoder = lambda module: not getattr(module, "is_causal", True)
    has_encoder = lambda model: any(is_encoder(m) for m in model.modules())
    is_multimodal = lambda config: config != config.get_text_config()
    # vLLM does not support encoder-decoder models, so if any encoder layer is
    # found in a text only model, we assume the whole model is an encoder model
    if has_encoder(self.model) and not is_multimodal(self.config):
        self.check_version("5.0.0", "encoder models support")
        attn_type = AttentionType.ENCODER_ONLY
    else:
        attn_type = AttentionType.DECODER
    # The attention class is the same for every layer
    attn_cls = (
        EncoderOnlyAttention
        if attn_type == AttentionType.ENCODER_ONLY
        else Attention
    )

    # Interleaved sliding window attention is described per layer by
    # `layer_types`. Multimodal models keep it on the text config (for text-only
    # models the text config is presumably the top-level config itself), so read
    # it from there and only fall back to the top-level config when the text
    # config does not define it. A `None` value means there is no pattern.
    sliding_config = text_config
    if getattr(sliding_config, "layer_types", None) is None:
        sliding_config = self.config
    layer_types = getattr(sliding_config, "layer_types", None)

    pp_rank = self.pp_group.rank_in_group
    pp_size = self.pp_group.world_size
    start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size)

    attention_instances = {}
    for i in range(start, end):
        # Handle interleaved sliding window attention
        per_layer_sliding_window = None
        if layer_types is not None and layer_types[i] == "sliding_attention":
            per_layer_sliding_window = sliding_config.sliding_window

        attention_instances[i] = attn_cls(
            num_heads=num_heads,
            head_size=head_size,
            # NOTE: We use Llama scale as default, if it's set by
            # Transformers, it's updated in vllm_flash_attention_forward
            scale=head_size**-0.5,
            num_kv_heads=num_kv_heads,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=f"{i}.attn",
            attn_type=attn_type,
        )
    return attention_instances

init_parameters

init_parameters(module: Module, dtype: dtype | None = None)

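Materialize every parameter of module (and its submodules) that is still on the meta device, allocating it on vLLM's device with the given dtype (or the model config's dtype). If a parameter is on the meta device, then its parent module is the original module created by: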

with torch.device("meta"):
    self.model: "PreTrainedModel" = AutoModel.from_config(...)
Source code in vllm/model_executor/models/transformers/base.py
def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None):
    """
    If a `parameter` is on the `meta` device, then its parent
    `module` is the original module created by:

    ```python
    with torch.device("meta"):
        self.model: "PreTrainedModel" = AutoModel.from_config(...)
    ```
    """

    def _init_parameters(module: nn.Module, dtype: torch.dtype | None):
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(
                        param.data,
                        dtype=dtype or self.model_config.dtype,
                        device=self.device_config.device,
                    )
                )
                setattr(module, name, new_param)
        for child in module.children():
            _init_parameters(child, dtype)

    _init_parameters(module, dtype)

pipeline_parallel

pipeline_parallel()

Apply the model's pipeline parallelization plan.

Source code in vllm/model_executor/models/transformers/base.py
def pipeline_parallel(self):
    """
    Apply the model's pipeline parallelization plan.

    Modules named in `self.model._pp_plan` that this rank does not own are
    swapped for `PPMissingLayer` placeholders:

    - modules listed before the `ModuleList` are kept only on the first rank
      (and also on the last rank when word embeddings are tied)
    - entries of the `ModuleList` outside this rank's `get_pp_indices` range
      are replaced
    - modules listed after the `ModuleList` are kept only on the last rank

    Does nothing when the pipeline-parallel world size is at most 1.

    Raises:
        ValueError: if the model does not support a pipeline-parallel plan, or
            its plan does not reference exactly one `ModuleList`.
    """
    pp_group = self.pp_group
    if pp_group.world_size <= 1:
        return

    if not self.model.supports_pp_plan:
        tip = get_feature_request_tip(
            self.model_config.model, self.model_config.trust_remote_code
        )
        raise ValueError(
            f"{type(self.model)} does not support pipeline parallel. {tip}"
        )

    def attrsetter(attr: str) -> Callable[[object, object], None]:
        """Set a possibly nested attribute, like the inverse of attrgetter."""
        owner_path, _, leaf = attr.rpartition(".")

        def setter(obj: object, value: object):
            owner = attrgetter(owner_path)(obj) if owner_path else obj
            setattr(owner, leaf, value)

        return setter

    def drop(name: str):
        # attrsetter in case the module is nested (e.g. "text_model.norm")
        attrsetter(name)(self.model, PPMissingLayer())

    pp_plan = list(self.model._pp_plan.keys())
    # attrgetter in case the module is nested (e.g. "text_model.layers")
    list_positions = [
        position
        for position, name in enumerate(pp_plan)
        if isinstance(attrgetter(name)(self.model), nn.ModuleList)
    ]
    if len(list_positions) > 1:
        raise ValueError(
            "Pipeline parallel of models with multiple `ModuleList`s "
            "in the base model are not supported yet!"
        )
    if not list_positions:
        raise ValueError(f"Could not find `ModuleList` in {type(self.model)}")
    module_list_idx = list_positions[0]

    # Modules before the module list (e.g. embeddings)
    leading = pp_plan[:module_list_idx]
    if leading and not (
        pp_group.is_first_rank
        or (self._get_tie_word_embeddings() and pp_group.is_last_rank)
    ):
        for name in leading:
            drop(name)

    # The module list itself: keep only the layers this rank owns
    first_owned, end_owned = get_pp_indices(
        self.text_config.num_hidden_layers,
        pp_group.rank_in_group,
        pp_group.world_size,
    )
    layers = attrgetter(pp_plan[module_list_idx])(self.model)
    for layer_idx in range(len(layers)):
        if not first_owned <= layer_idx < end_owned:
            layers[layer_idx] = PPMissingLayer()

    # Modules after the module list only live on the last rank
    if not pp_group.is_last_rank:
        for name in pp_plan[module_list_idx + 1 :]:
            drop(name)

recursive_replace

recursive_replace()

Recursively replace modules in the model as needed.

Currently, this replaces:

  • nn.Linear with vLLM's tensor parallel linear classes
  • *RMSNorm with vLLM's RMSNorm
  • nn.Conv2d and nn.Conv3d with the modules returned by replace_conv_class
Source code in vllm/model_executor/models/transformers/base.py
def recursive_replace(self):
    """Recursively replace modules in the model as needed.

    Currently, this replaces:

    - `nn.Linear` with vLLM's tensor parallel linear classes
    - `*RMSNorm` with vLLM's `RMSNorm`
    - `nn.Conv2d`/`nn.Conv3d` with the module returned by `replace_conv_class`

    While walking the model it also records Eagle3 metadata for the decoder
    layers (a `ModuleList` whose length equals `text_config.num_hidden_layers`)
    and registers the prefixes of any MTP layers numbered after them in
    `self.ignore_unexpected_prefixes`, so those weights are not loaded into
    the base model.

    Raises:
        ValueError: if tensor parallelism is requested but the model has no
            tensor parallel plan.
    """
    # Treat a missing (`None`) plan like an empty one so that runs without
    # tensor parallelism replicate every linear layer instead of crashing on
    # `.items()` below
    tp_plan = self.model.tp_plan or {}

    if not tp_plan and self.tp_group.world_size > 1:
        tip = get_feature_request_tip(
            self.model_config.model, self.model_config.trust_remote_code
        )
        raise ValueError(
            f"{type(self.model)} does not support tensor parallel. {tip}"
        )

    # Prefix the patterns because we always start from `self.model`
    tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()}
    num_hidden_layers = self.text_config.num_hidden_layers

    def _ignore_mtp_layers(prefix: str):
        """Register prefixes of MTP layers that follow the decoder layers."""
        # MTP weights should not be loaded into the base model
        names = (
            "n_predict",  # Override from SpeculativeConfig
            "num_nextn_predict_layers",  # Most models
            "mtp_num_hidden_layers",  # Qwen 3.5
        )
        n_predict = getattr_iter(self.text_config, names, 0)
        for i in range(num_hidden_layers, num_hidden_layers + n_predict):
            mtp_prefix = f"{prefix}.{i}."
            if mtp_prefix not in self.ignore_unexpected_prefixes:
                self.ignore_unexpected_prefixes.append(mtp_prefix)

    def _recursive_replace(module: nn.Module, prefix: str):
        # The decoder layers are identified as a `ModuleList` with one entry
        # per hidden layer
        is_decoder_layers = (
            isinstance(module, nn.ModuleList) and len(module) == num_hidden_layers
        )
        # The MTP prefixes depend only on this `ModuleList`, so register them
        # once rather than once per layer (non-empty lists only, as before)
        if is_decoder_layers and len(module) > 0:
            _ignore_mtp_layers(prefix)

        for child_name, child_module in module.named_children():
            new_module = child_module
            qual_name = maybe_prefix(prefix, child_name)
            if is_decoder_layers:
                # Populate Eagle3 attrs
                self._target_class = type(child_module)
                layer_name = qual_name.removeprefix("model.")
                self._layer_names[int(child_name)] = layer_name
            # Replace modules as needed
            if isinstance(child_module, nn.Linear):
                generator = (p for p in tp_plan if re.match(p, qual_name))
                pattern = next(generator, None)
                # Some weight loaders expect all linear layers to inherit
                # LinearBase, so we set a default style which causes any
                # unspecified layers to be replaced with ReplicatedLinear
                style = tp_plan.get(pattern, "replicate")
                new_module = replace_linear_class(
                    child_module, style, self.quant_config, prefix=qual_name
                )
            elif isinstance(child_module, (nn.Conv2d, nn.Conv3d)):
                new_module = replace_conv_class(child_module)
            elif child_module.__class__.__name__.endswith("RMSNorm"):
                new_module = replace_rms_norm_class(
                    child_module, self.text_config.hidden_size
                )
            else:
                _recursive_replace(child_module, prefix=qual_name)

            if new_module is not child_module:
                setattr(module, child_name, new_module)
                log_replacement(qual_name, child_module, new_module)

    _recursive_replace(self.model, prefix="model")