
feat: Mistral-Large-2 support in the Pytorch workflow #3845

Merged · 8 commits · Apr 30, 2025
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
@@ -5,6 +5,7 @@
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
from .modeling_llama import LlamaForCausalLM
from .modeling_llava_next import LlavaNextModel
from .modeling_mistral import MistralForCausalLM
from .modeling_mixtral import MixtralForCausalLM
from .modeling_nemotron import NemotronForCausalLM
from .modeling_nemotron_h import NemotronHForCausalLM
@@ -23,6 +24,7 @@
"DeepseekV3ForCausalLM",
"LlamaForCausalLM",
"LlavaNextModel",
"MistralForCausalLM",
"MixtralForCausalLM",
"NemotronForCausalLM",
"NemotronHForCausalLM",
@@ -39,6 +41,7 @@

if transformers.__version__ >= "4.45.1":
from .modeling_mllama import MllamaForConditionalGeneration # noqa

__all__.append("MllamaForConditionalGeneration")
else:
print(
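
Taken together, the changes to __init__.py only import the new class and export it through __all__. A minimal visibility check, assuming an environment with tensorrt_llm installed (illustrative, not part of the PR):

from tensorrt_llm._torch import models

# The added import and __all__ entry make the new model class discoverable here.
assert "MistralForCausalLM" in models.__all__
print(models.MistralForCausalLM)
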
192 changes: 192 additions & 0 deletions tensorrt_llm/_torch/models/modeling_mistral.py
@@ -0,0 +1,192 @@
from typing import Any, Optional, Tuple

import torch
from torch import nn
from transformers import MistralConfig

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import (
PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
register_auto_model,
support_pp)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType


class MistralAttention(Attention):
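    """Mistral self-attention block.

    Maps the Hugging Face ``MistralConfig`` fields (hidden size, GQA query/KV
    head counts, max position embeddings, RoPE parameters) onto the shared
    ``Attention`` module, using GPT-NeoX-style rotary position embeddings.
    """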

def __init__(
self,
model_config: ModelConfig[MistralConfig],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=RopeParams.from_config(config),
),
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)


class MistralDecoderLayer(DecoderLayer):
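    """Single pre-norm Mistral decoder layer.

    RMSNorm -> self-attention -> RMSNorm -> gated MLP, with the residual
    carried alongside the hidden states so each norm can fuse the residual
    add with the normalization.
    """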

def __init__(
self,
model_config: ModelConfig[MistralConfig],
layer_idx: int,
    ):
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx

self.self_attn = MistralAttention(
model_config,
layer_idx=layer_idx,
)

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
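        # The first layer of the stack has no incoming residual; afterwards
        # the RMSNorm fuses the residual add with the normalization.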
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
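        # Let speculative decoding (via SpecMetadata) capture this layer's
        # hidden states and residual if the active method needs them.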
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
hidden_states, residual)
return hidden_states, residual


@support_pp
class MistralModel(DecoderModel):
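    """Mistral decoder stack: tensor-parallel token embedding (with gathered
    output), ``num_hidden_layers`` ``MistralDecoderLayer`` blocks, and a
    final RMSNorm.
    """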

def __init__(self, model_config: ModelConfig[MistralConfig]):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.padding_idx = config.pad_token_id

self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
MistralDecoderLayer(
model_config,
layer_idx,
) for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
lora_params: Optional[Any] = None,
) -> torch.Tensor:
        # Exactly one of `input_ids` or `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds."
            )

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

hidden_states = inputs_embeds

residual = None
for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
lora_params=lora_params,
)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


@register_auto_model("MistralForCausalLM")
class MistralForCausalLM(DecoderModelForCausalLM[MistralModel, MistralConfig]):
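    """Causal-LM wrapper pairing ``MistralModel`` with the LM head and weight
    loading provided by ``DecoderModelForCausalLM``. Registered under the
    Hugging Face architecture name ``MistralForCausalLM`` so matching
    checkpoints are routed to this implementation in the PyTorch workflow.
    """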

def __init__(
self,
model_config: ModelConfig[MistralConfig],
):
super().__init__(
MistralModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
)
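
For reference, a minimal construction sketch (not part of the PR): it assumes a ModelConfig can be built directly from a Hugging Face MistralConfig with defaults for everything else, which mirrors how the attributes are read above but is untested here.

import torch
from transformers import MistralConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mistral import MistralForCausalLM

# Deliberately tiny, hypothetical hyperparameters so instantiation stays cheap.
hf_config = MistralConfig(
    num_hidden_layers=2,
    hidden_size=256,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,
    torch_dtype=torch.float16,
)
model_config = ModelConfig(pretrained_config=hf_config)
model = MistralForCausalLM(model_config)
print(sum(p.numel() for p in model.parameters()))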