import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
-from safetensors.torch import save_file
-from safetensors.torch import load_model
+from safetensors.torch import load_model, save_file
from transformers import CLIPImageProcessor

from ..runtime.session import Session


def add_multimodal_arguments(parser):
-    parser.add_argument('--model_type',
-                        type=str,
-                        default=None,
-                        choices=[
-                            'blip2', 'llava', 'llava_next', 'llava_onevision',
-                            'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm',
-                            'fuyu', 'pix2struct', 'neva', 'kosmos-2',
-                            'video-neva', 'phi-3-vision', 'phi-4-multimodal',
-                            'mllama', 'internvl', 'qwen2_vl',
-                            'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair'
-                        ],
-                        help="Model type")
+    parser.add_argument(
+        '--model_type',
+        type=str,
+        default=None,
+        choices=[
+            'blip2', 'llava', 'llava_next', 'llava_onevision',
+            'llava_onevision_lmms', 'vila', 'nougat', 'cogvlm', 'fuyu',
+            'pix2struct', 'neva', 'kosmos-2', 'video-neva', 'phi-3-vision',
+            'phi-4-multimodal', 'mllama', 'internvl', 'qwen2_vl',
+            'internlm-xcomposer2', 'qwen2_audio', 'pixtral', 'eclair'
+        ],
+        help="Model type")
    parser.add_argument(
        '--model_path',
        type=str,
@@ -1743,20 +1742,33 @@ def forward(self, pixel_values, attention_mask):
        engine_name=f"model.engine",
        dtype=torch.bfloat16)

+
def build_eclair_engine(args):
-
+
    class RadioWithNeck(torch.nn.Module):
+
        def __init__(self):
            super().__init__()

-            self.model_encoder = torch.hub.load("NVlabs/RADIO", "radio_model", version="radio_v2.5-h")
+            self.model_encoder = torch.hub.load("NVlabs/RADIO",
+                                                "radio_model",
+                                                version="radio_v2.5-h")
            self.model_encoder.summary_idxs = torch.tensor(4)

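            # neck: a 1x1 Conv1d projects RADIO's 1280-dim features to 1024,
            # then a (1, 4) strided conv merges groups of 4 along the last axis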
            self.conv1 = torch.nn.Conv1d(1280, 1024, 1)
-            self.layer_norm1 = torch.nn.LayerNorm(1024, eps=1e-6, elementwise_affine=True)
-            self.conv2 = torch.nn.Conv2d(1024, 1024, kernel_size=(1, 4), stride=(1, 4), padding=0, bias=False)
-            self.layer_norm2 = torch.nn.LayerNorm(1024, eps=1e-6, elementwise_affine=True)
-
+            self.layer_norm1 = torch.nn.LayerNorm(1024,
+                                                  eps=1e-6,
+                                                  elementwise_affine=True)
+            self.conv2 = torch.nn.Conv2d(1024,
+                                         1024,
+                                         kernel_size=(1, 4),
+                                         stride=(1, 4),
+                                         padding=0,
+                                         bias=False)
+            self.layer_norm2 = torch.nn.LayerNorm(1024,
+                                                  eps=1e-6,
+                                                  elementwise_affine=True)
+
        @torch.no_grad
        def forward(self, pixel_values):
            _, feature = self.model_encoder(pixel_values)
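            # the RADIO encoder returns (summary, spatial features); only the
            # spatial features are passed on to the neck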
@@ -1770,26 +1782,29 @@ def forward(self, pixel_values):
            output = output.flatten(-2, -1).permute(0, 2, 1)
            output = self.layer_norm2(output)
            return output
-
+
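    # assemble Nougat's encoder-decoder, swap its vision tower for
    # RadioWithNeck, then load the fine-tuned weights from the checkpoint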
    processor = NougatProcessor.from_pretrained(args.model_path)
    model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
    model.encoder = RadioWithNeck()
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
-    model.config.decoder_start_token_id = processor.tokenizer.eos_token_id # 2
+    model.config.decoder_start_token_id = processor.tokenizer.eos_token_id  # 2
    model.config.pad_token_id = processor.tokenizer.pad_token_id  # 1
    load_model(model, os.path.join(args.model_path, "model.safetensors"))
-
+
    wrapper = model.encoder.to(args.device)
    # temporary fix for a TRT ONNX export bug: fall back to unfused attention
    for block in wrapper.model_encoder.model.blocks:
        block.attn.fused_attn = False
-
-    image = torch.randn((1, 3, 2048, 1648), device=args.device, dtype=torch.float16)
+
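+    # dummy image used only to trace the encoder for ONNX export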
+    image = torch.randn((1, 3, 2048, 1648),
+                        device=args.device,
+                        dtype=torch.float16)
    export_onnx(wrapper, image, f'{args.output_dir}/onnx')
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        f'{args.output_dir}/onnx',
        args.output_dir,
        args.max_batch_size,
-        dtype=torch.bfloat16,engine_name='visual_encoder.engine')
+        dtype=torch.bfloat16,
+        engine_name='visual_encoder.engine')
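
For reference, a minimal sketch of how the new eclair path could be driven end to end. It assumes the elided part of add_multimodal_arguments also registers --output_dir, --device and --max_batch_size (build_eclair_engine reads all three from args); the paths below are placeholders:

    import argparse

    parser = argparse.ArgumentParser()
    add_multimodal_arguments(parser)
    args = parser.parse_args([
        '--model_type', 'eclair',
        '--model_path', '/ckpt/eclair',        # placeholder checkpoint dir
        '--output_dir', '/tmp/eclair_engine',  # ONNX + engine output dir
    ])
    build_eclair_engine(args)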