4 changes: 2 additions & 2 deletions docs/contributing/model/basic.md
@@ -56,13 +56,13 @@ The initialization code should look like this:

### Computation Code

-- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
+- Add an `embed_input_ids` method inside the `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.

```python
class MyModel(nn.Module):
    ...

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...
```
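
For reference, a minimal runnable sketch of the renamed method, assuming a plain `nn.Embedding` layer named `embed_tokens` with hypothetical sizes (real vLLM models use a tensor-parallel `VocabParallelEmbedding` here and read the sizes from the HF config):

```python
import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self, vocab_size: int = 32_000, hidden_size: int = 4_096):
        super().__init__()
        # Hypothetical sizes, for illustration only.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Equivalent to calling the embedding layer directly, but gives
        # composite multimodal wrappers a uniform entry point.
        return self.embed_tokens(input_ids)


# Token IDs in, embeddings of shape (num_tokens, hidden_size) out.
embeds = MyModel().embed_input_ids(torch.tensor([1, 2, 3]))
```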

6 changes: 3 additions & 3 deletions docs/contributing/model/multimodal.md
@@ -36,7 +36,7 @@ Further update the model as follows:

More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.

-- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
+- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide boilerplate for a typical implementation pattern; feel free to adjust it to your own needs.

??? code

@@ -49,7 +49,7 @@ Further update the model as follows:
        image_features = self.vision_encoder(image_input)
        return self.multi_modal_projector(image_features)

-    def get_multimodal_embeddings(
+    def embed_multimodal(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings | None:
@@ -69,7 +69,7 @@ Further update the model as follows:
!!! note
By default, vLLM merges the multimodal embeddings into the text embeddings based on their locations, as defined by
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] during input processing.
-This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
+This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].

You may override this method if additional logic is required for your model when merging embeddings.
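
To make that concrete, here is a simplified sketch of the default merge, with the `is_multimodal` mask convention taken from the `aria.py` call site later in this diff (the real `SupportsMultiModal.embed_input_ids` handles more edge cases):

```python
import torch


def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings=None,
    *,
    is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
    # Embed every position as a text token first.
    inputs_embeds = self.get_language_model().embed_input_ids(input_ids)
    if multimodal_embeddings is None or is_multimodal is None:
        return inputs_embeds
    # Flatten the per-item embeddings and scatter them into the
    # placeholder positions marked True in the mask.
    flat = torch.cat(list(multimodal_embeddings), dim=0)
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[is_multimodal] = flat.to(inputs_embeds.dtype)
    return inputs_embeds
```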

8 changes: 4 additions & 4 deletions vllm/model_executor/models/apertus.py
@@ -382,7 +382,7 @@ def __init__(
["hidden_states", "residual"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -396,7 +396,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
@@ -557,8 +557,8 @@ def _init_model(
vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/arcee.py
@@ -239,7 +239,7 @@ def __init__(
["hidden_states", "residual"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -254,7 +254,7 @@ def forward(
hidden_states = (
inputs_embeds
if inputs_embeds is not None
-else self.get_input_embeddings(input_ids)
+else self.embed_input_ids(input_ids)
)
residual = None
else:
@@ -423,8 +423,8 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights into the model (delegates to inner model and handles
8 changes: 4 additions & 4 deletions vllm/model_executor/models/arctic.py
@@ -442,7 +442,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
["hidden_states"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -456,7 +456,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
@@ -496,8 +496,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model.make_empty_intermediate_tensors
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
6 changes: 3 additions & 3 deletions vllm/model_executor/models/aria.py
@@ -613,7 +613,7 @@ def _process_image_input(
def get_language_model(self) -> torch.nn.Module:
return self.language_model

-def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -629,8 +629,8 @@ def forward(
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is None:
-multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-inputs_embeds = self.get_input_embeddings(
+multimodal_embeddings = self.embed_multimodal(**kwargs)
+inputs_embeds = self.embed_input_ids(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
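
The `aria.py` hunk above is the wiring every multimodal model in this PR converges on. Spelled out as a standalone sketch (the trailing `input_ids = None` hand-off and the exact `language_model` call signature are assumptions based on the usual vLLM pattern, not part of this hunk):

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    if inputs_embeds is None:
        # 1) Encode the multimodal inputs into LM-space embeddings.
        multimodal_embeddings = self.embed_multimodal(**kwargs)
        # 2) Merge them into the text embeddings at placeholder positions.
        inputs_embeds = self.embed_input_ids(
            input_ids,
            multimodal_embeddings,
            is_multimodal=input_ids == self.config.image_token_index,
        )
        # Pass only embeddings downstream so tokens are not embedded twice.
        input_ids = None
    return self.language_model(
        input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
    )
```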
2 changes: 1 addition & 1 deletion vllm/model_executor/models/aya_vision.py
@@ -417,7 +417,7 @@ def _parse_and_validate_image_input(
def get_language_model(self) -> torch.nn.Module:
return self.language_model

-def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
8 changes: 4 additions & 4 deletions vllm/model_executor/models/baichuan.py
@@ -309,7 +309,7 @@ def __init__(
["hidden_states", "residual"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -323,7 +323,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
@@ -426,8 +426,8 @@ def __init__(
self.model.make_empty_intermediate_tensors
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/bailing_moe.py
@@ -438,7 +438,7 @@ def __init__(
else:
self.norm = PPMissingLayer()

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)

def forward(
@@ -452,7 +452,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
@@ -608,8 +608,8 @@ def __init__(
self.model.make_empty_intermediate_tensors
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/bamba.py
@@ -314,7 +314,7 @@ def get_layer(prefix: str):

self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -328,7 +328,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
@@ -493,8 +493,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model.make_empty_intermediate_tensors
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
14 changes: 7 additions & 7 deletions vllm/model_executor/models/bert.py
@@ -375,7 +375,7 @@ def __init__(
self.embeddings = embedding_class(self.config)
self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.word_embeddings(input_ids)

def forward(
@@ -486,8 +486,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.pooler = self._build_pooler(pooler_config)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.model.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.model.embed_input_ids(input_ids)

def forward(
self,
@@ -835,8 +835,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
}
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.bert.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.bert.embed_input_ids(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
@@ -893,8 +893,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
}
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.bert.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.bert.embed_input_ids(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
6 changes: 3 additions & 3 deletions vllm/model_executor/models/bert_with_rope.py
@@ -463,7 +463,7 @@ def __init__(
)
self.pooler = BertPooler(self.config) if add_pooling_layer else None

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)

def forward(
@@ -714,8 +714,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loaded_params = loader.load_weights(weights)
return loaded_params

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.new.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.new.embed_input_ids(input_ids)

def forward(
self,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/blip2.py
@@ -630,7 +630,7 @@ def _process_image_input(self, image_input: Blip2ImageInputs) -> torch.Tensor:
def get_language_model(self) -> torch.nn.Module:
return self.language_model

-def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
8 changes: 4 additions & 4 deletions vllm/model_executor/models/bloom.py
@@ -271,7 +271,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
["hidden_states"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)

def forward(
@@ -285,7 +285,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else:
assert intermediate_tensors is not None
@@ -353,8 +353,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.transformer.make_empty_intermediate_tensors
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-    return self.transformer.get_input_embeddings(input_ids)
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+    return self.transformer.embed_input_ids(input_ids)

def forward(
self,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/chameleon.py
@@ -886,7 +886,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
["hidden_states", "residual"], config.hidden_size
)

-def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
@@ -912,7 +912,7 @@ def forward(
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
-hidden_states = self.get_input_embeddings(input_ids)
+hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
@@ -998,15 +998,15 @@ def _parse_and_validate_image_input(
def get_language_model(self) -> torch.nn.Module:
return self.model

-def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.dtype)
)
-vision_embeddings = self.model.get_input_embeddings(image_tokens)
+vision_embeddings = self.model.embed_input_ids(image_tokens)
return vision_embeddings

def forward(
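
Worth noting: `chameleon.py` is the outlier in this batch. Its VQ-VAE (`self.model.vqmodel`) quantizes images into discrete codebook tokens, so the renamed `embed_input_ids` does double duty here, embedding image tokens through the same lookup used for text.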