Skip to content

Commit 97d1c99

Browse files
authored
Rename clashing method names for vLLM model protocol (#27583)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 3226283 commit 97d1c99

File tree

164 files changed

+574
-583
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

164 files changed

+574
-583
lines changed

docs/contributing/model/basic.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ The initialization code should look like this:
5656

5757
### Computation Code
5858

59-
- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
59+
- Add an `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
6060

6161
```python
6262
class MyModel(nn.Module):
6363
...
6464

65-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
65+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
6666
...
6767
```
6868

docs/contributing/model/multimodal.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Further update the model as follows:
3636

3737
More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
3838

39-
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
39+
- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
4040

4141
??? code
4242

@@ -49,7 +49,7 @@ Further update the model as follows:
4949
image_features = self.vision_encoder(image_input)
5050
return self.multi_modal_projector(image_features)
5151

52-
def get_multimodal_embeddings(
52+
def embed_multimodal(
5353
self,
5454
**kwargs: object,
5555
) -> MultiModalEmbeddings | None:
@@ -69,7 +69,7 @@ Further update the model as follows:
6969
!!! note
7070
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
7171
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
72-
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
72+
This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
7373

7474
You may override this method if additional logic is required for your model when merging embeddings.
7575

vllm/model_executor/models/apertus.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def __init__(
382382
["hidden_states", "residual"], config.hidden_size
383383
)
384384

385-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
385+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
386386
return self.embed_tokens(input_ids)
387387

388388
def forward(
@@ -396,7 +396,7 @@ def forward(
396396
if inputs_embeds is not None:
397397
hidden_states = inputs_embeds
398398
else:
399-
hidden_states = self.get_input_embeddings(input_ids)
399+
hidden_states = self.embed_input_ids(input_ids)
400400
residual = None
401401
else:
402402
assert intermediate_tensors is not None
@@ -557,8 +557,8 @@ def _init_model(
557557
vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
558558
)
559559

560-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
561-
return self.model.get_input_embeddings(input_ids)
560+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
561+
return self.model.embed_input_ids(input_ids)
562562

563563
def forward(
564564
self,

vllm/model_executor/models/arcee.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
["hidden_states", "residual"], config.hidden_size
240240
)
241241

242-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
242+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
243243
return self.embed_tokens(input_ids)
244244

245245
def forward(
@@ -254,7 +254,7 @@ def forward(
254254
hidden_states = (
255255
inputs_embeds
256256
if inputs_embeds is not None
257-
else self.get_input_embeddings(input_ids)
257+
else self.embed_input_ids(input_ids)
258258
)
259259
residual = None
260260
else:
@@ -423,8 +423,8 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
423423
logits = self.logits_processor(self.lm_head, hidden_states)
424424
return logits
425425

426-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
427-
return self.model.get_input_embeddings(input_ids)
426+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
427+
return self.model.embed_input_ids(input_ids)
428428

429429
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
430430
"""Load weights into the model (delegates to inner model and handles

vllm/model_executor/models/arctic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
442442
["hidden_states"], config.hidden_size
443443
)
444444

445-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
445+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
446446
return self.embed_tokens(input_ids)
447447

448448
def forward(
@@ -456,7 +456,7 @@ def forward(
456456
if inputs_embeds is not None:
457457
hidden_states = inputs_embeds
458458
else:
459-
hidden_states = self.get_input_embeddings(input_ids)
459+
hidden_states = self.embed_input_ids(input_ids)
460460
else:
461461
assert intermediate_tensors is not None
462462
hidden_states = intermediate_tensors["hidden_states"]
@@ -496,8 +496,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
496496
self.model.make_empty_intermediate_tensors
497497
)
498498

499-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
500-
return self.model.get_input_embeddings(input_ids)
499+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
500+
return self.model.embed_input_ids(input_ids)
501501

502502
def forward(
503503
self,

vllm/model_executor/models/aria.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def _process_image_input(
613613
def get_language_model(self) -> torch.nn.Module:
614614
return self.language_model
615615

616-
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
616+
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
617617
image_input = self._parse_and_validate_image_input(**kwargs)
618618
if image_input is None:
619619
return []
@@ -629,8 +629,8 @@ def forward(
629629
**kwargs: object,
630630
) -> torch.Tensor | IntermediateTensors:
631631
if inputs_embeds is None:
632-
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
633-
inputs_embeds = self.get_input_embeddings(
632+
multimodal_embeddings = self.embed_multimodal(**kwargs)
633+
inputs_embeds = self.embed_input_ids(
634634
input_ids,
635635
multimodal_embeddings,
636636
is_multimodal=input_ids == self.config.image_token_index,

vllm/model_executor/models/aya_vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def _parse_and_validate_image_input(
417417
def get_language_model(self) -> torch.nn.Module:
418418
return self.language_model
419419

420-
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
420+
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
421421
image_input = self._parse_and_validate_image_input(**kwargs)
422422
if image_input is None:
423423
return []

vllm/model_executor/models/baichuan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
["hidden_states", "residual"], config.hidden_size
310310
)
311311

312-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
312+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
313313
return self.embed_tokens(input_ids)
314314

315315
def forward(
@@ -323,7 +323,7 @@ def forward(
323323
if inputs_embeds is not None:
324324
hidden_states = inputs_embeds
325325
else:
326-
hidden_states = self.get_input_embeddings(input_ids)
326+
hidden_states = self.embed_input_ids(input_ids)
327327
residual = None
328328
else:
329329
assert intermediate_tensors is not None
@@ -426,8 +426,8 @@ def __init__(
426426
self.model.make_empty_intermediate_tensors
427427
)
428428

429-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
430-
return self.model.get_input_embeddings(input_ids)
429+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
430+
return self.model.embed_input_ids(input_ids)
431431

432432
def forward(
433433
self,

vllm/model_executor/models/bailing_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def __init__(
438438
else:
439439
self.norm = PPMissingLayer()
440440

441-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
441+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
442442
return self.word_embeddings(input_ids)
443443

444444
def forward(
@@ -452,7 +452,7 @@ def forward(
452452
if inputs_embeds is not None:
453453
hidden_states = inputs_embeds
454454
else:
455-
hidden_states = self.get_input_embeddings(input_ids)
455+
hidden_states = self.embed_input_ids(input_ids)
456456
residual = None
457457
else:
458458
assert intermediate_tensors is not None
@@ -608,8 +608,8 @@ def __init__(
608608
self.model.make_empty_intermediate_tensors
609609
)
610610

611-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
612-
return self.model.get_input_embeddings(input_ids)
611+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
612+
return self.model.embed_input_ids(input_ids)
613613

614614
def forward(
615615
self,

vllm/model_executor/models/bamba.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def get_layer(prefix: str):
314314

315315
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
316316

317-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
317+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
318318
return self.embed_tokens(input_ids)
319319

320320
def forward(
@@ -328,7 +328,7 @@ def forward(
328328
if inputs_embeds is not None:
329329
hidden_states = inputs_embeds
330330
else:
331-
hidden_states = self.get_input_embeddings(input_ids)
331+
hidden_states = self.embed_input_ids(input_ids)
332332
residual = None
333333
else:
334334
assert intermediate_tensors is not None
@@ -493,8 +493,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
493493
self.model.make_empty_intermediate_tensors
494494
)
495495

496-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
497-
return self.model.get_input_embeddings(input_ids)
496+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
497+
return self.model.embed_input_ids(input_ids)
498498

499499
def forward(
500500
self,

0 commit comments

Comments
 (0)