@@ -133,6 +133,19 @@ def __init__(
133
133
self .output_dir = "."
134
134
self ._saved_pte_filename = None
135
135
136
def __post_init__(self):
    """
    Sync the ``get_max_seq_len`` metadata entry with the dynamic shape.

    Asks ``_get_dynamic_shape()`` for the dynamic shape spec. When one is
    returned, the ``max`` of the entry at ``[0][1]`` is stored under
    ``self.metadata["get_max_seq_len"]``; if ``self.verbose`` is set, the
    new value is logged first. When no spec is returned (``None``), the
    metadata is left untouched.

    NOTE(review): Python only calls ``__post_init__`` automatically from a
    dataclass-generated ``__init__``. This class appears to define its own
    ``__init__``, so confirm that this hook is invoked explicitly.
    """
    shape_spec = self._get_dynamic_shape()
    if shape_spec is None:
        # No dynamic shape configured: nothing to reconcile.
        return

    # Assumes the layout built by `_get_dynamic_shape`: first element is a
    # dict whose key 1 maps to the "token_dim" Dim object. A caller-supplied
    # `dynamic_shapes` with another layout would fail here -- TODO confirm.
    max_tokens = shape_spec[0][1].max
    if self.verbose:
        logging.info(
            f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {max_tokens} "
        )
    self.metadata["get_max_seq_len"] = max_tokens
148
+
136
149
def set_output_dir (self , output_dir : str ) -> "LLMEdgeManager" :
137
150
"""
138
151
Set the directory where the .pte file will be saved.
@@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any:
180
193
if self .dynamic_shapes :
181
194
return self .dynamic_shapes
182
195
183
- dim = torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )
184
196
if self .enable_dynamic_shape :
185
197
if not self .use_kv_cache :
186
198
# Only one input argument: tokens
187
- self .dynamic_shapes = ({1 : dim },)
199
+ # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
200
+ self .dynamic_shapes = (
201
+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len - 1 )},
202
+ )
188
203
else :
189
204
# Two input arguments: tokens and input_pos but input_pos is static shape
190
- self .dynamic_shapes = ({1 : dim }, {"input_pos" : {0 : 1 }})
205
+ self .dynamic_shapes = (
206
+ {1 : torch .export .Dim ("token_dim" , max = self .max_seq_len )},
207
+ {"input_pos" : {0 : 1 }},
208
+ )
191
209
else :
192
210
# Two input arguments: tokens and input_pos but both are of static shape
193
211
self .dynamic_shapes = None
0 commit comments