Skip to content

Commit 57e0765

Browse files
authored
[llm] Update metadata max_seq_len based on the max range of dynamic shapes
Differential Revision: D76530379 Pull Request resolved: #11611
1 parent 5960a4b commit 57e0765

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

extension/llm/export/builder.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ def __init__(
133133
self.output_dir = "."
134134
self._saved_pte_filename = None
135135

136+
def __post_init__(self):
137+
"""
138+
Post init function to update metadata based on dynamic shape
139+
"""
140+
dynamic_shape = self._get_dynamic_shape()
141+
if dynamic_shape is not None:
142+
token_dim = dynamic_shape[0][1]
143+
if self.verbose:
144+
logging.info(
145+
f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
146+
)
147+
self.metadata["get_max_seq_len"] = token_dim.max
148+
136149
def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
137150
"""
138151
Set the directory where the .pte file will be saved.
@@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any:
180193
if self.dynamic_shapes:
181194
return self.dynamic_shapes
182195

183-
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
184196
if self.enable_dynamic_shape:
185197
if not self.use_kv_cache:
186198
# Only one input argument: tokens
187-
self.dynamic_shapes = ({1: dim},)
199+
# Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
200+
self.dynamic_shapes = (
201+
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
202+
)
188203
else:
189204
# Two input arguments: tokens and input_pos but input_pos is static shape
190-
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
205+
self.dynamic_shapes = (
206+
{1: torch.export.Dim("token_dim", max=self.max_seq_len)},
207+
{"input_pos": {0: 1}},
208+
)
191209
else:
192210
# Two input arguments: tokens and input_pos but both are of static shape
193211
self.dynamic_shapes = None

extension/llm/export/test/test_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888
# Check first element (tokens dimension)
8989
self.assertIsInstance(result[0], dict)
9090
self.assertIn(1, result[0])
91-
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
91+
self.assertEqual(result[0][1].max, self.max_seq_len)
9292

9393
# Check second element (input_pos dimension)
9494
self.assertIsInstance(result[1], dict)

0 commit comments

Comments
 (0)