175 changes: 160 additions & 15 deletions examples/models/core/enc_dec/convert_checkpoint.py
@@ -14,10 +14,11 @@
 from helper import (convert_weight_to_dtype, fairseq_sin_pos_embedding,
                     fuse_qkv_one_layer, reshape, split)
 from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
-                          MBartForConditionalGeneration,
+                          MBartForConditionalGeneration, NougatProcessor,
                           Pix2StructForConditionalGeneration,
                           T5ForConditionalGeneration, VisionEncoderDecoderModel)

+from tensorrt_llm._utils import pad_vocab_size
 from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                      MLPType)
 from tensorrt_llm.layers import LanguageAdapterConfig
@@ -30,6 +31,9 @@
 layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
 mlp_type_map = {i.name: i.value for i in MLPType}

+# Constants for specific model configurations
+ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
+

 def copy_args_to_component_config(component_config, args):
     for arg in vars(args):
@@ -619,14 +623,19 @@ def parse_bart_config(args, hf_model):
     config = configparser.ConfigParser()

     config['decoder'] = dict()
-    for key, val in hf_model.model.decoder.config.to_dict().items():
-        config["decoder"][key] = f"{val}"
+    if args.eclair_radio:
+        for key, val in hf_model.config.to_dict().items():
+            config["decoder"][key] = f"{val}"
+    else:
+        for key, val in hf_model.model.decoder.config.to_dict().items():
+            config["decoder"][key] = f"{val}"
     config["decoder"]["q_scaling"] = '1'
     config["decoder"]["rescale_before_lm_head"] = str(False)
     config['decoder']['has_model_final_layernorm'] = str(
-        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))
+        args.nougat or args.eclair_radio
+        or isinstance(hf_model, MBartForConditionalGeneration))

-    if args.nougat:
+    if args.nougat or args.eclair_radio:
         # These flags are true for mbart decoders, but missing in HF config
         config['decoder']['normalize_before'] = str(True)
         config['decoder']['normalize_embeddings'] = str(True)
@@ -763,10 +772,14 @@ def parse_bart_config_by_component(config, component, args):
         return component_config

     encoder_config = None
-    if not args.nougat:
+    if not (args.nougat or args.eclair_radio):
         encoder_config = parse_bart_config_by_component(config, "encoder", args)
     decoder_config = parse_bart_config_by_component(config, "decoder", args)

+    # Override n_positions for eclair_radio model
+    if args.eclair_radio:
+        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
+
     return encoder_config, decoder_config


@@ -952,11 +965,22 @@ def get_attn_module_name(component, layer, attn_type):
                      (hidden_size * 3 // mapping.tp_size)))

     if component == 'decoder':
+        import torch
+        lm_head_weights = params['lm_head.weight'].clone().detach()
+        vocab_size = config.vocab_size
+        if params['lm_head.weight'].shape[0] % mapping.tp_size != 0:
+            vocab_size_padded = pad_vocab_size(config.vocab_size,
+                                               mapping.tp_size)
+            pad_width = vocab_size_padded - config.vocab_size
+
+            lm_head_weights = torch.nn.functional.pad(lm_head_weights,
+                                                      (0, 0, 0, pad_width),
+                                                      'constant',
+                                                      value=0)
+            vocab_size = vocab_size_padded
         weights['lm_head.weight'] = reshape(
-            split(params['lm_head.weight'],
-                  mapping.tp_size,
-                  mapping.tp_rank,
-                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
+            split(lm_head_weights, mapping.tp_size, mapping.tp_rank, dim=0),
+            (vocab_size // mapping.tp_size, hidden_size))

         if config.has_model_final_layernorm:
             weights['transformer.ln_f.weight'] = params[
@@ -1479,6 +1503,113 @@ def get_model(args):
         if args.nougat:
             model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
             model = model.get_decoder()
+        elif args.eclair_radio:
+            import torch
+
+            class RadioWithNeck(torch.nn.Module):
+
+                def __init__(self):
+                    super().__init__()
+
+                    self.model_encoder = torch.hub.load("NVlabs/RADIO",
+                                                        "radio_model",
+                                                        version="radio_v2.5-h")
+                    self.model_encoder.summary_idxs = torch.tensor(4)
+
+                    self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
+                    self.layer_norm1 = torch.nn.LayerNorm(
+                        1024, eps=1e-6, elementwise_affine=True)
+                    self.conv2 = torch.nn.Conv2d(1024,
+                                                 1024,
+                                                 kernel_size=(1, 4),
+                                                 stride=(1, 4),
+                                                 padding=0,
+                                                 bias=False)
+                    self.layer_norm2 = torch.nn.LayerNorm(
+                        1024, eps=1e-6, elementwise_affine=True)
+
+                def forward(self, pixel_values):
+                    _, feature = self.model_encoder(pixel_values)
+                    output = self.conv1(feature.permute(0, 2,
+                                                        1)).permute(0, 2, 1)
+                    output = self.layer_norm1(output).permute(0, 2, 1)
+
+                    b, d, _ = output.shape
+                    h = pixel_values.shape[-2] // 16
+                    w = pixel_values.shape[-1] // 16
+                    output = self.conv2(output.reshape(b, d, h, w))
+                    output = output.flatten(-2, -1).permute(0, 2, 1)
+                    output = self.layer_norm2(output)
+                    return output
+
+            def get_processor():
+                processor = NougatProcessor.from_pretrained(
+                    "facebook/nougat-base")
+
+                special_tokens = {
+                    "output_plain_index": "<output_plain>",
+                    "output_markdown_index": "<output_markdown>",
+                    "output_no_text_index": "<output_no_text>",
+                    "output_ocr_index": "<output_ocr>",
+                    "predict_bbox_index": "<predict_bbox>",
+                    "no_bbox_index": "<no_bbox>",
+                    "bbox_start_index": "<bbox>",  # not used but can keep
+                    # "bbox_end_index": "</bbox>",  # not used but can keep
+                    "no_class_index": "<no_classes>",
+                    "predict_classes_index": "<predict_classes>",
+                }
+                for key, special_t in special_tokens.items():
+                    processor.tokenizer.add_special_tokens(
+                        {"additional_special_tokens": [special_t]})
+                    setattr(processor.tokenizer, key,
+                            processor.tokenizer.encode(special_t)[1])
+
+                # Add regular tokens for boxes
+                processor.tokenizer.add_tokens(
+                    [f"<x_{x_i}>" for x_i in range(1024)])
+                processor.tokenizer.add_tokens(
+                    [f"<y_{y_i}>" for y_i in range(1280)])
+                # Add regular tokens for classes
+                # "<class_{class_i}>"
+                possible_classes = [
+                    "Text", "Title", "Section-header", "List-item", "TOC",
+                    "Bibliography", "Footnote", "Page-header", "Page-footer",
+                    "Picture", "Formula", "Page-number", "Table", "Caption"
+                ]
+                processor.tokenizer.add_tokens(
+                    [f"<class_{cls}>" for cls in possible_classes])
+                return processor
+
+            processor = get_processor()
+            model = VisionEncoderDecoderModel.from_pretrained(
+                "facebook/nougat-base")
+            model.encoder = RadioWithNeck()
+            model.decoder.resize_token_embeddings(len(processor.tokenizer),
+                                                  pad_to_multiple_of=64)
+            model.config.decoder_start_token_id = processor.tokenizer.eos_token_id  # 2
+            model.config.pad_token_id = processor.tokenizer.pad_token_id  # 1
+            from transformers.models.mbart.modeling_mbart import \
+                MBartLearnedPositionalEmbedding
+            _, d_model = model.device, model.config.decoder.d_model
+
+            with torch.inference_mode():
+                # Load checkpoint weights (non-strict: the encoder was replaced)
+                safetensors.torch.load_model(model,
+                                             os.path.join(
+                                                 args.model_dir,
+                                                 "model.safetensors"),
+                                             strict=False)
+                model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(
+                    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
+                model.decoder.model.decoder.embed_positions.weight.data.zero_()
+                model.decoder.model.decoder.embed_positions.weight.requires_grad_(
+                    True)
+                model.decoder.lm_head.weight = model.decoder.get_input_embeddings(
+                ).weight
+
+            model.eval()
+            model = model.get_decoder()
+
         else:
             model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
     elif args.model_type == "pix2struct":
@@ -1522,14 +1653,23 @@ def convert_checkpoint(args):
     quant_algo = None

     model_type = args.model_type if args.model_type != "blip2" else "t5"
-    encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
-        args, model)
+    parse_config_mapper = {
+        't5': parse_t5_config,
+        'pix2struct': parse_pix2struct_config,
+        'blip2': parse_t5_config,  # blip2 uses t5 config parser
+        'language_adapter': parse_language_adapter_config,
+        'nmt': parse_nmt_config,
+        'bart': parse_bart_config,
+    }
+    encoder_config, decoder_config = parse_config_mapper[model_type](args,
+                                                                     model)

     additional_settings = ["gated_act"]
     if model_type == 'language_adapter':
         additional_settings += ["residual_scaling", "language_adapter_config"]

-    if not args.nougat and args.model_type != "pix2struct":
+    if not (args.nougat
+            or args.eclair_radio) and args.model_type != "pix2struct":
         tllm_encoder_config = {
             'architecture': "EncoderModel",
             'dtype': args.dtype,
@@ -1664,7 +1804,8 @@ def convert_checkpoint(args):
         decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding

     if args.workers == 1:
-        if not args.nougat and args.model_type != "pix2struct":
+        if not (args.nougat
+                or args.eclair_radio) and args.model_type != "pix2struct":
             convert(0, world_size, args, tllm_encoder_config,
                     encoder_convert_args, encoder_saved_dir)
         convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
@@ -1674,7 +1815,8 @@
         args.workers = world_size
         LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
         import torch.multiprocessing as mp
-        if not args.nougat and args.model_type != "pix2struct":
+        if not (args.nougat
+                or args.eclair_radio) and args.model_type != "pix2struct":
             mp.spawn(convert,
                      nprocs=args.workers,
                      args=(world_size, args, tllm_encoder_config,
@@ -1736,6 +1878,9 @@ def convert(worker_rank, world_size, args, model_config, convert_args,
     parser.add_argument("--nougat",
                         action="store_true",
                         help="Model which uses vision encoder + mbart decoder")
+    parser.add_argument("--eclair_radio",
+                        action="store_true",
+                        help="Model which uses RADIO vision encoder + mbart decoder")
     parser.add_argument("--verbose",
                         action="store_true",
                         help="Provide verbose messages")
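For reference, a standalone sketch of what the new lm_head handling above does, assuming pad_vocab_size rounds the vocabulary up to the next multiple of tp_size (the function and variable names below are illustrative, not part of this PR):

import torch

def pad_and_shard_lm_head(lm_head: torch.Tensor, tp_size: int,
                          tp_rank: int) -> torch.Tensor:
    # lm_head is [vocab_size, hidden]; zero-pad rows so the vocab dimension
    # splits evenly across tensor-parallel ranks, then take this rank's slice.
    vocab_size, _ = lm_head.shape
    vocab_padded = -(-vocab_size // tp_size) * tp_size  # ceil to a multiple
    if vocab_padded != vocab_size:
        lm_head = torch.nn.functional.pad(
            lm_head, (0, 0, 0, vocab_padded - vocab_size), value=0)
    return lm_head.chunk(tp_size, dim=0)[tp_rank]

With a 50265-row lm_head and tp_size=4 this pads to 50268 rows, giving each rank a 12567 x hidden shard; when the vocabulary already divides tp_size it is a no-op.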
1 change: 1 addition & 0 deletions examples/models/core/multimodal/requirements-eclair.txt
@@ -0,0 +1 @@
+timm
8 changes: 5 additions & 3 deletions tensorrt_llm/models/enc_dec/model.py
@@ -20,7 +20,8 @@
 import torch

 from tensorrt_llm._common import default_net
-from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
+from tensorrt_llm._utils import (numpy_to_torch, pad_vocab_size,
+                                 str_dtype_to_torch)
 from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                      MLPType, PositionEmbeddingType, Tensor,
                                      assertion, cast, gather_last_token_logits,
@@ -1156,9 +1157,11 @@ def __init__(self, config: PretrainedConfig):
         self.transformer.assign_module(decoder_layers, "layers")

         if self.mapping.is_last_pp_rank():
+            vocab_size_padded = pad_vocab_size(self.config.vocab_size,
+                                               self.config.mapping.tp_size)
             self.lm_head = ColumnLinear(
                 self.config.hidden_size,
-                self.config.vocab_size,
+                vocab_size_padded,
                 bias=False if not hasattr(self.config, "has_lm_head_bias") else
                 self.config.has_lm_head_bias,
                 dtype=self.config.dtype,
@@ -1208,7 +1211,6 @@ def check_config(self, config: PretrainedConfig):
         config.set_if_not_exist('num_buckets', None)
         config.set_if_not_exist('max_distance', None)
         config.set_if_not_exist('relative_attention', False)
-        config.set_if_not_exist('residual_scaling', 1.0)

     def forward(self,
                 decoder_input_ids: Tensor,
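The model-side change mirrors the converter: the ColumnLinear output dimension is rounded up so it stays divisible by tp_size, which is what lets the padded lm_head shards written by convert_checkpoint.py load cleanly. A minimal sketch of that invariant, assuming pad_vocab_size returns the smallest multiple of tp_size at or above vocab_size:

def pad_vocab_size(vocab_size: int, tp_size: int) -> int:
    # Assumed semantics: smallest multiple of tp_size >= vocab_size.
    return (vocab_size + tp_size - 1) // tp_size * tp_size

for tp_size in (1, 2, 4, 8):
    padded = pad_vocab_size(50265, tp_size)
    assert padded % tp_size == 0          # lm_head splits evenly across ranks
    assert 0 <= padded - 50265 < tp_size  # at most tp_size - 1 zero rows added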